fix def use search

This commit is contained in:
2025-02-25 17:12:21 +03:00
parent 859d70daa7
commit 6671780a97
2 changed files with 47 additions and 36 deletions

View File

@@ -123,10 +123,8 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
}
auto operation = instruction->getInstruction()->getOperation();
auto type = instruction->getInstruction()->getArg1()->getType();
if ((operation == SAPFOR::CFG_OP::STORE && type == SAPFOR::CFG_ARG_TYPE::ARRAY) ||
(operation == SAPFOR::CFG_OP::LOAD && type == SAPFOR::CFG_ARG_TYPE::ARRAY))
if ((operation == SAPFOR::CFG_OP::STORE || operation == SAPFOR::CFG_OP::LOAD) && type == SAPFOR::CFG_ARG_TYPE::ARRAY)
{
vector<SAPFOR::Argument*> index_vars;
vector<int> refPos;
string array_name;
@@ -147,7 +145,8 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
}
/*to choose correct dimension*/
int n = index_vars.size();
if (operation == SAPFOR::CFG_OP::STORE)
vector<ArrayDimension> accessPoint(n);
/*if (operation == SAPFOR::CFG_OP::STORE)
{
if (def[array_name].empty())
{
@@ -160,7 +159,7 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
{
use[array_name].resize(n);
}
}
}*/
SgArrayRefExp* ref = (SgArrayRefExp*)instruction->getInstruction()->getExpression();
vector<pair<int, int>> coefsForDims;
@@ -214,18 +213,27 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
uint64_t step = currentCoefs.first;
current_dim = { start, step, (uint64_t)currentLoop->calculatedCountOfIters };
}
if (operation == SAPFOR::CFG_OP::STORE)
/*if (operation == SAPFOR::CFG_OP::STORE)
{
def[array_name][n - index_vars.size()].push_back(current_dim);
}
else
{
use[array_name][n - index_vars.size()].push_back(current_dim);
}
}*/
accessPoint[n - index_vars.size()] = current_dim;
index_vars.pop_back();
refPos.pop_back();
coefsForDims.pop_back();
}
if (operation == SAPFOR::CFG_OP::STORE)
{
def[array_name].Insert(accessPoint);
}
else
{
use[array_name].Insert(accessPoint);
}
}
}
return 0;
@@ -407,8 +415,8 @@ static void ElementsUnion(const vector<ArrayDimension>& firstElement, const vect
rc = ElementsDifference(secondElement, intersection);
}
void AccessingSet::FindUncovered(const vector<ArrayDimension>& element, vector<vector<ArrayDimension>>& result) {
vector<vector<ArrayDimension>> result, newTails;
void AccessingSet::FindUncovered(const vector<ArrayDimension>& element, vector<vector<ArrayDimension>>& result) const{
vector<vector<ArrayDimension>> newTails;
result.push_back(element);
for(const auto& currentElement: allElements)
{
@@ -424,41 +432,38 @@ void AccessingSet::FindUncovered(const vector<ArrayDimension>& element, vector<v
}
}
bool AccessingSet::ContainsElement(const vector<ArrayDimension>& element)
bool AccessingSet::ContainsElement(const vector<ArrayDimension>& element) const
{
vector<vector<ArrayDimension>> tails;
FindUncovered(element, tails);
return !tails.empty();
}
void AccessingSet::FindCoveredBy(const vector<ArrayDimension>& element, vector<vector<ArrayDimension>>& result)
void AccessingSet::FindCoveredBy(const vector<ArrayDimension>& element, vector<vector<ArrayDimension>>& result) const
{
for(const auto& currentElement: allElements)
{
for(const auto& tailLoc: tails)
{
auto intersection = ElementsIntersection(tailLoc, currentElement);
if(!intersection.empty()) {
result.push_back(intersection);
}
auto intersection = ElementsIntersection(element, currentElement);
if(!intersection.empty()) {
result.push_back(intersection);
}
}
}
vector<vector<ArrayDimension>>> AccessingSet::GetElements()
vector<vector<ArrayDimension>> AccessingSet::GetElements() const
{
return AllElements;
return allElements;
}
void AccessingSet::Insert(const vector<ArrayDimension>& element)
{
vector<vector<ArrayDimension>> tails;
FindUncovered(element, tails);
AllElements.insert(AllElements.end(), tails.begin(), tails.end());
allElements.insert(allElements.end(), tails.begin(), tails.end());
}
void AccessingSet::Union(const AccessingSet& source) {
for(const auto element: source.GetElements()) {
for(auto& element: source.GetElements()) {
Insert(element);
}
}
@@ -466,16 +471,16 @@ void AccessingSet::Union(const AccessingSet& source) {
vector<vector<ArrayDimension>> AccessingSet::Intersect(const AccessingSet& secondSet)
{
vector<vector<ArrayDimension>> result;
for(const auto& element: AllElements)
for(const auto& element: allElements)
{
if(ContainsElement(secondSet), element)
if(secondSet.ContainsElement(element))
{
result.push_back(element)
result.push_back(element);
}
else
{
vector<vector<ArrayDimension> coveredBy;
FindCoveredBy(secondSet, element, coveredBy);
vector<vector<ArrayDimension>> coveredBy;
secondSet.FindCoveredBy(element, coveredBy);
if(!coveredBy.empty())
{
result.insert(result.end(), coveredBy.begin(), coveredBy.end());

View File

@@ -13,22 +13,28 @@ struct ArrayDimension
uint64_t start, step, tripCount;
};
typedef map<string, AccessingSet> ArrayAccessingIndexes;
class AccessingSet {
prinvate:
vector<vector<ArrayDimension>>> AllElements;
bool ContainsElement(const vector<ArrayDimension>& element);
void FindCoveredBy(const vector<ArrayDimension>& element, vector<vector<ArrayDimension>>& result);
void FindUncovered(const vector<ArrayDimension>& element, vector<vector<ArrayDimension>>& result);
private:
vector<vector<ArrayDimension>> allElements;
public:
AccessingSet(vector<vector<ArrayDimension>>> input): AllElements(input) {};
vector<vector<ArrayDimension>>> GetElements();
AccessingSet(vector<vector<ArrayDimension>> input) : allElements(input) {};
AccessingSet() {};
vector<vector<ArrayDimension>> GetElements() const;
void Insert(const vector<ArrayDimension>& element);
void Union(const AccessingSet& source);
vector<vector<ArrayDimension>> Intersect(const AccessingSet& secondSet);
bool ContainsElement(const vector<ArrayDimension>& element) const;
void FindCoveredBy(const vector<ArrayDimension>& element, vector<vector<ArrayDimension>>& result) const;
void FindUncovered(const vector<ArrayDimension>& element, vector<vector<ArrayDimension>>& result) const;
};
typedef map<string, AccessingSet> ArrayAccessingIndexes;
void FindPrivateArrays(map<string, vector<LoopGraph*>>& loopGraph, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR);
void GetDimensionInfo(LoopGraph* loop, map<DIST::Array*, vector<vector<ArrayDimension>>>& loopDimensionsInfo, int level);
set<SAPFOR::BasicBlock> GetBasicBlocksForLoop(LoopGraph* loop, vector<SAPFOR::BasicBlock>);