add ddot, change array propagation

This commit is contained in:
2026-03-12 04:25:45 +03:00
committed by ALEXks
parent 39abbafb3a
commit 97e60e16be
3 changed files with 175 additions and 139 deletions

View File

@@ -9,7 +9,10 @@
using namespace std; using namespace std;
static SgStatement* declPlace = NULL; SgStatement* declPlace = NULL;
unordered_set<SgStatement*> changed;
unordered_set<SgSymbol*> variablesToAdd;
unordered_set<SgStatement*> positionsToAdd;
static bool CheckConstIndexes(SgExpression* exp) static bool CheckConstIndexes(SgExpression* exp)
{ {
@@ -40,137 +43,117 @@ static SgExpression* CreateVar(int& variableNumber, SgType* type)
string name = varName + std::to_string(variableNumber) + "__"; string name = varName + std::to_string(variableNumber) + "__";
variableNumber++; variableNumber++;
SgSymbol* varSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), *type, *declPlace->controlParent()); SgStatement* funcStart = declPlace->controlParent();
SgSymbol* varSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), *type, *funcStart);
variablesToAdd.insert(varSymbol);
positionsToAdd.insert(declPlace);
return new SgExpression(VAR_REF, NULL, NULL, varSymbol, type->copyPtr());
}
static SgStatement* FindLastDeclStatement(SgStatement* funcStart)
{
SgStatement* endSt = funcStart->lastNodeOfStmt();
SgStatement* cur = funcStart->lexNext();
SgStatement* lastDecl = funcStart;
const set<int> declVariants = { VAR_DECL, VAR_DECL_90, ALLOCATABLE_STMT, DIM_STAT,
EXTERN_STAT, COMM_STAT, HPF_TEMPLATE_STAT, DVM_VAR_DECL, STRUCT_DECL };
while (cur && cur != endSt)
{
if (cur->variant() == INTERFACE_STMT)
cur = cur->lastNodeOfStmt();
if (declVariants.find(cur->variant()) != declVariants.end())
lastDecl = cur;
else if (isSgExecutableStatement(cur))
break;
cur = cur->lexNext();
}
return lastDecl;
}
static void InsertCommonAndDeclsForFunction(SgStatement* funcStart, const unordered_set<SgSymbol*>& symbols)
{
if (symbols.empty())
return;
const string commonBlockName = "__propagation_common__"; const string commonBlockName = "__propagation_common__";
SgStatement* funcStart = declPlace->controlParent(); SgStatement* funcEnd = funcStart->lastNodeOfStmt();
SgStatement* commonStat = NULL; SgStatement* commonStat = NULL;
SgExpression* commonList = NULL; SgExpression* commonList = NULL;
SgStatement* funcEnd = funcStart->lastNodeOfStmt(); for (SgStatement* cur = funcStart->lexNext();
SgStatement* current = funcStart->lexNext(); cur && cur != funcEnd; cur = cur->lexNext())
while (current != funcEnd && current)
{ {
if (current->variant() == COMM_STAT) if (cur->variant() != COMM_STAT)
continue;
for (SgExpression* exp = cur->expr(0); exp; exp = exp->rhs())
{ {
for (SgExpression* exp = current->expr(0); exp; exp = exp->rhs()) if (exp->variant() != COMM_LIST)
continue;
const char* id = exp->symbol() ? exp->symbol()->identifier() : NULL;
string existingName = id ? string(id) : string("spf_unnamed");
if (existingName == commonBlockName)
{ {
if (exp->variant() == COMM_LIST) commonStat = cur;
{ commonList = exp;
string existingName = exp->symbol() ?
string(exp->symbol()->identifier()) :
string("spf_unnamed");
if (existingName == commonBlockName)
{
commonStat = current;
commonList = exp;
break;
}
}
}
if (commonStat)
break; break;
}
} }
current = current->lexNext(); if (commonStat)
break;
} }
vector<SgExpression*> varRefs; vector<SgExpression*> varRefs;
if (commonList) for (SgSymbol* sym : symbols)
{ varRefs.push_back(new SgVarRefExp(sym));
SgExpression* varList = commonList->lhs(); SgExpression* varList = makeExprList(varRefs, false);
if (varList)
{
auto extractSymbol = [](SgExpression* exp) -> SgSymbol* {
if (!exp)
return NULL;
if (exp->symbol())
return exp->symbol();
if (exp->lhs() && exp->lhs()->symbol())
return exp->lhs()->symbol();
return NULL;
};
if (varList->variant() == EXPR_LIST)
{
for (SgExpression* exp = varList; exp; exp = exp->rhs())
{
SgExpression* varExp = exp->lhs();
SgSymbol* sym = extractSymbol(varExp);
if (sym)
{
varRefs.push_back(new SgVarRefExp(sym));
}
}
}
else
{
for (SgExpression* varExp = varList; varExp; varExp = varExp->rhs())
{
SgSymbol* sym = extractSymbol(varExp);
if (sym)
{
varRefs.push_back(new SgVarRefExp(sym));
}
}
}
}
}
if (!commonList) if (!commonList)
{ {
current = funcStart->lexNext();
while (current != funcEnd && current)
{
if (current->variant() == COMM_STAT)
{
commonStat = current;
break;
}
current = current->lexNext();
}
SgSymbol* commonSymbol = new SgSymbol(COMMON_NAME, commonBlockName.c_str()); SgSymbol* commonSymbol = new SgSymbol(COMMON_NAME, commonBlockName.c_str());
commonList = new SgExpression(COMM_LIST, NULL, NULL, commonSymbol); commonList = new SgExpression(COMM_LIST, varList, NULL, commonSymbol);
if (commonStat) commonStat = new SgStatement(COMM_STAT);
{ commonStat->setFileName(funcStart->fileName());
SgExpression* lastCommList = commonStat->expr(0); commonStat->setFileId(funcStart->getFileId());
if (lastCommList) commonStat->setProject(funcStart->getProject());
{ commonStat->setlineNumber(getNextNegativeLineNumber());
while (lastCommList->rhs()) commonStat->setExpression(0, commonList);
lastCommList = lastCommList->rhs();
lastCommList->setRhs(commonList);
}
else
{
commonStat->setExpression(0, commonList);
}
}
else
{
commonStat = new SgStatement(COMM_STAT);
commonStat->setFileName(declPlace->fileName());
commonStat->setFileId(declPlace->getFileId());
commonStat->setProject(declPlace->getProject());
commonStat->setlineNumber(getNextNegativeLineNumber());
commonStat->setExpression(0, commonList);
declPlace->insertStmtBefore(*commonStat, *declPlace->controlParent());
}
SgStatement* lastDecl = FindLastDeclStatement(funcStart);
lastDecl->insertStmtAfter(*commonStat, *funcStart);
} }
varRefs.push_back(new SgVarRefExp(varSymbol)); else
if (varRefs.size() > 0)
{ {
std::reverse(varRefs.begin(), varRefs.end());
SgExpression* varList = makeExprList(varRefs, false);
commonList->setLhs(varList); commonList->setLhs(varList);
} }
return new SgExpression(VAR_REF, NULL, NULL, varSymbol, type->copyPtr()); SgStatement* insertAfter = commonStat;
for (SgSymbol* sym : symbols)
{
SgStatement* declStmt = sym->makeVarDeclStmt();
if (!declStmt)
continue;
if (SgVarDeclStmt* vds = isSgVarDeclStmt(declStmt))
vds->setVariant(VAR_DECL_90);
declStmt->setFileName(funcStart->fileName());
declStmt->setFileId(funcStart->getFileId());
declStmt->setProject(funcStart->getProject());
declStmt->setlineNumber(getNextNegativeLineNumber());
insertAfter->insertStmtAfter(*declStmt, *funcStart);
insertAfter = declStmt;
}
} }
static void TransformRightPart(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber) static void TransformRightPart(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
@@ -227,6 +210,8 @@ static void TransformLeftPart(SgStatement* st, SgExpression* exp, unordered_map<
{ {
if (exp->symbol()->type()->variant() == T_STRING) if (exp->symbol()->type()->variant() == T_STRING)
return; return;
if (changed.find(st) != changed.end())
return;
string expUnparsed = exp->unparse(); string expUnparsed = exp->unparse();
if (arrayToVariable.find(expUnparsed) == arrayToVariable.end() && exp->symbol()->type()->baseType()) if (arrayToVariable.find(expUnparsed) == arrayToVariable.end() && exp->symbol()->type()->baseType())
{ {
@@ -234,12 +219,62 @@ static void TransformLeftPart(SgStatement* st, SgExpression* exp, unordered_map<
} }
SgStatement* newStatement = new SgStatement(ASSIGN_STAT, NULL, NULL, arrayToVariable[expUnparsed]->copyPtr(), st->expr(1)->copyPtr(), NULL); SgStatement* newStatement = new SgStatement(ASSIGN_STAT, NULL, NULL, arrayToVariable[expUnparsed]->copyPtr(), st->expr(1)->copyPtr(), NULL);
newStatement->setFileId(st->getFileId()); newStatement->setFileId(st->getFileId());
newStatement->setProject(st->getProject()); newStatement->setProject(st->getProject());
newStatement->setlineNumber(getNextNegativeLineNumber()); newStatement->setlineNumber(getNextNegativeLineNumber());
newStatement->setLocalLineNumber(st->lineNumber()); newStatement->setLocalLineNumber(st->lineNumber());
st->insertStmtBefore(*newStatement, *st->controlParent()); st->insertStmtBefore(*newStatement, *st->controlParent());
changed.insert(st);
}
static void TransformBorder(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
SgStatement* firstStatement = declPlace->lexPrev();
st = st->lexPrev();
string array = exp->unparse();
arrayToVariable[array] = CreateVar(variableNumber, exp->symbol()->type()->baseType());
while (st != firstStatement)
{
if (st->variant() == ASSIGN_STAT && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end())
{
if (st->expr(1))
{
TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);
}
if (st->expr(0) && st->expr(0)->variant() == ARRAY_REF && CheckConstIndexes(st->expr(0)->lhs()) && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end())
{
TransformLeftPart(st, st->expr(0), arrayToVariable, variableNumber);
}
}
st = st->lexPrev();
}
}
static void CheckVariable(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
SgStatement* firstStatement = declPlace->lexPrev();
st = st->lexPrev();
string varName = exp->unparse();
while (st != firstStatement)
{
if (st->variant() == ASSIGN_STAT && st->expr(0)->symbol() == exp->symbol())
{
TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);
}
if (st->variant() == ASSIGN_STAT && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end())
{
if (st->expr(1))
{
TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);
}
if (st->expr(0) && st->expr(0)->variant() == ARRAY_REF && CheckConstIndexes(st->expr(0)->lhs()) && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end())
{
TransformLeftPart(st, st->expr(0), arrayToVariable, variableNumber);
}
}
st = st->lexPrev();
}
} }
void ArrayConstantPropagation(SgProject& project) void ArrayConstantPropagation(SgProject& project)
@@ -262,40 +297,37 @@ void ArrayConstantPropagation(SgProject& project)
for (; st != lastNode; st = st->lexNext()) for (; st != lastNode; st = st->lexNext())
{ {
if (st->variant() == ASSIGN_STAT) if (st->variant() == FOR_NODE)
{
if (st->expr(1))
{
TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);
}
if (st->expr(0) && st->expr(0)->variant() == ARRAY_REF && CheckConstIndexes(st->expr(0)->lhs()))
{
TransformLeftPart(st, st->expr(0), arrayToVariable, variableNumber);
}
}
else if (st->variant() == FOR_NODE)
{ {
SgExpression* lowerBound = st->expr(0)->lhs(); SgExpression* lowerBound = st->expr(0)->lhs();
SgExpression* upperBound = st->expr(0)->rhs(); SgExpression* upperBound = st->expr(0)->rhs();
string lowerBoundUnparsed = lowerBound->unparse(), upperBoundUnparsed = upperBound->unparse(); string lowerBoundUnparsed = lowerBound->unparse(), upperBoundUnparsed = upperBound->unparse();
if (upperBound->variant() == ARRAY_REF && upperBound->symbol()->type()->baseType() && CheckConstIndexes(upperBound->lhs())) if (upperBound->variant() == ARRAY_REF && upperBound->symbol()->type()->baseType() && CheckConstIndexes(upperBound->lhs()))
{ {
if (arrayToVariable.find(upperBoundUnparsed) == arrayToVariable.end()) TransformBorder(st, upperBound, arrayToVariable, variableNumber);
{
arrayToVariable[upperBoundUnparsed] = CreateVar(variableNumber, upperBound->symbol()->type()->baseType());
}
st->expr(0)->setRhs(arrayToVariable[upperBoundUnparsed]->copyPtr()); st->expr(0)->setRhs(arrayToVariable[upperBoundUnparsed]->copyPtr());
} }
else if (upperBound->variant() == VAR_REF)
CheckVariable(st, upperBound, arrayToVariable, variableNumber);
if (lowerBound->variant() == ARRAY_REF && lowerBound->symbol()->type()->baseType() && CheckConstIndexes(lowerBound->lhs())) if (lowerBound->variant() == ARRAY_REF && lowerBound->symbol()->type()->baseType() && CheckConstIndexes(lowerBound->lhs()))
{ {
if (arrayToVariable.find(lowerBoundUnparsed) == arrayToVariable.end()) TransformBorder(st, lowerBound, arrayToVariable, variableNumber);
{
arrayToVariable[lowerBoundUnparsed] = CreateVar(variableNumber, lowerBound->symbol()->type()->baseType());
}
st->expr(0)->setLhs(arrayToVariable[lowerBoundUnparsed]->copyPtr()); st->expr(0)->setLhs(arrayToVariable[lowerBoundUnparsed]->copyPtr());
} }
else if (lowerBound->variant() == VAR_REF)
CheckVariable(st, lowerBound, arrayToVariable, variableNumber);
} }
} }
} }
unordered_set<SgStatement*> funcStarts;
for (SgStatement* st : positionsToAdd)
{
SgStatement* scope = st->controlParent();
if (scope)
funcStarts.insert(scope);
}
for (const auto& st: funcStarts)
InsertCommonAndDeclsForFunction(st, variablesToAdd);
} }
} }

View File

@@ -33,10 +33,9 @@ static void RemoveEmptyPoints(ArrayAccessingIndexes& container)
points.push_back(arrayPoint); points.push_back(arrayPoint);
} }
if (points.size() < accessingSet.GetElements().size() && !points.empty()) if (!points.empty())
resultContainer[arrayName] = points; resultContainer[arrayName] = points;
else
if (points.empty())
toRemove.insert(arrayName); toRemove.insert(arrayName);
} }
@@ -281,7 +280,6 @@ static void SolveDataFlow(Region* DFG)
static bool getArrayDeclaredDimensions(SgArrayRefExp* arrayRef, vector<uint64_t>& declaredDims) static bool getArrayDeclaredDimensions(SgArrayRefExp* arrayRef, vector<uint64_t>& declaredDims)
{ {
declaredDims.clear();
if (!arrayRef || !arrayRef->symbol() || !isSgArrayType(arrayRef->symbol()->type())) if (!arrayRef || !arrayRef->symbol() || !isSgArrayType(arrayRef->symbol()->type()))
return false; return false;
SgArrayType* arrayType = (SgArrayType*)arrayRef->symbol()->type(); SgArrayType* arrayType = (SgArrayType*)arrayRef->symbol()->type();
@@ -290,17 +288,22 @@ static bool getArrayDeclaredDimensions(SgArrayRefExp* arrayRef, vector<uint64_t>
{ {
SgExpression* sizeExpr = arrayType->sizeInDim(i); SgExpression* sizeExpr = arrayType->sizeInDim(i);
SgConstantSymb* constValSymb = isSgConstantSymb(sizeExpr->symbol()); SgConstantSymb* constValSymb = isSgConstantSymb(sizeExpr->symbol());
string strDimLength; SgSubscriptExp* subscriptExpr = isSgSubscriptExp(sizeExpr);
uint64_t dimLength;
if (sizeExpr && sizeExpr->variant() == INT_VAL) if (sizeExpr && sizeExpr->variant() == INT_VAL)
strDimLength = sizeExpr->unparse(); dimLength = stol(sizeExpr->unparse());
else if (constValSymb) else if (constValSymb)
strDimLength = constValSymb->constantValue()->unparse(); dimLength = stol(constValSymb->constantValue()->unparse());
else if (subscriptExpr)
{
dimLength = stol(subscriptExpr->rhs()->unparse()) - stol(subscriptExpr->lhs()->unparse());
}
else else
return false; return false;
if (strDimLength == "0") if (dimLength == 0)
return false; return false;
declaredDims.push_back((uint64_t)stoi(strDimLength)); declaredDims.push_back(dimLength);
} }
return true; return true;
} }
@@ -313,7 +316,8 @@ static bool CheckDimensionLength(const AccessingSet& array)
SgArrayRefExp* arrayRef = array.GetElements()[0][0].array; SgArrayRefExp* arrayRef = array.GetElements()[0][0].array;
if (!arrayRef) if (!arrayRef)
return false; return false;
vector<uint64_t> declaredDims(dimCount); vector<uint64_t> declaredDims;
declaredDims.reserve(dimCount);
if (!getArrayDeclaredDimensions(arrayRef, declaredDims)) if (!getArrayDeclaredDimensions(arrayRef, declaredDims))
return false; return false;
vector<ArrayDimension> testArray(dimCount); vector<ArrayDimension> testArray(dimCount);

View File

@@ -204,7 +204,7 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
current_dim = { start, step, iters, ref }; current_dim = { start, step, iters, ref };
} }
if (current_dim.start != 0 && current_dim.step != 0 && current_dim.tripCount != 0) if (current_dim.step != 0 && current_dim.tripCount != 0)
{ {
accessPoint[n - index_vars.size()] = current_dim; accessPoint[n - index_vars.size()] = current_dim;
fillCount++; fillCount++;