add ddot, change array propagation

This commit is contained in:
2026-03-12 04:25:45 +03:00
committed by ALEXks
parent 39abbafb3a
commit 97e60e16be
3 changed files with 175 additions and 139 deletions

View File

@@ -9,7 +9,10 @@
using namespace std;
static SgStatement* declPlace = NULL;
SgStatement* declPlace = NULL;
unordered_set<SgStatement*> changed;
unordered_set<SgSymbol*> variablesToAdd;
unordered_set<SgStatement*> positionsToAdd;
static bool CheckConstIndexes(SgExpression* exp)
{
@@ -40,137 +43,117 @@ static SgExpression* CreateVar(int& variableNumber, SgType* type)
string name = varName + std::to_string(variableNumber) + "__";
variableNumber++;
SgSymbol* varSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), *type, *declPlace->controlParent());
SgStatement* funcStart = declPlace->controlParent();
SgSymbol* varSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), *type, *funcStart);
variablesToAdd.insert(varSymbol);
positionsToAdd.insert(declPlace);
return new SgExpression(VAR_REF, NULL, NULL, varSymbol, type->copyPtr());
}
static SgStatement* FindLastDeclStatement(SgStatement* funcStart)
{
SgStatement* endSt = funcStart->lastNodeOfStmt();
SgStatement* cur = funcStart->lexNext();
SgStatement* lastDecl = funcStart;
const set<int> declVariants = { VAR_DECL, VAR_DECL_90, ALLOCATABLE_STMT, DIM_STAT,
EXTERN_STAT, COMM_STAT, HPF_TEMPLATE_STAT, DVM_VAR_DECL, STRUCT_DECL };
while (cur && cur != endSt)
{
if (cur->variant() == INTERFACE_STMT)
cur = cur->lastNodeOfStmt();
if (declVariants.find(cur->variant()) != declVariants.end())
lastDecl = cur;
else if (isSgExecutableStatement(cur))
break;
cur = cur->lexNext();
}
return lastDecl;
}
static void InsertCommonAndDeclsForFunction(SgStatement* funcStart, const unordered_set<SgSymbol*>& symbols)
{
if (symbols.empty())
return;
const string commonBlockName = "__propagation_common__";
SgStatement* funcStart = declPlace->controlParent();
SgStatement* funcEnd = funcStart->lastNodeOfStmt();
SgStatement* commonStat = NULL;
SgExpression* commonList = NULL;
SgStatement* funcEnd = funcStart->lastNodeOfStmt();
SgStatement* current = funcStart->lexNext();
while (current != funcEnd && current)
for (SgStatement* cur = funcStart->lexNext();
cur && cur != funcEnd; cur = cur->lexNext())
{
if (current->variant() == COMM_STAT)
if (cur->variant() != COMM_STAT)
continue;
for (SgExpression* exp = cur->expr(0); exp; exp = exp->rhs())
{
for (SgExpression* exp = current->expr(0); exp; exp = exp->rhs())
if (exp->variant() != COMM_LIST)
continue;
const char* id = exp->symbol() ? exp->symbol()->identifier() : NULL;
string existingName = id ? string(id) : string("spf_unnamed");
if (existingName == commonBlockName)
{
if (exp->variant() == COMM_LIST)
{
string existingName = exp->symbol() ?
string(exp->symbol()->identifier()) :
string("spf_unnamed");
if (existingName == commonBlockName)
{
commonStat = current;
commonList = exp;
break;
}
}
}
if (commonStat)
commonStat = cur;
commonList = exp;
break;
}
}
current = current->lexNext();
if (commonStat)
break;
}
vector<SgExpression*> varRefs;
if (commonList)
{
SgExpression* varList = commonList->lhs();
if (varList)
{
auto extractSymbol = [](SgExpression* exp) -> SgSymbol* {
if (!exp)
return NULL;
if (exp->symbol())
return exp->symbol();
if (exp->lhs() && exp->lhs()->symbol())
return exp->lhs()->symbol();
return NULL;
};
if (varList->variant() == EXPR_LIST)
{
for (SgExpression* exp = varList; exp; exp = exp->rhs())
{
SgExpression* varExp = exp->lhs();
SgSymbol* sym = extractSymbol(varExp);
if (sym)
{
varRefs.push_back(new SgVarRefExp(sym));
}
}
}
else
{
for (SgExpression* varExp = varList; varExp; varExp = varExp->rhs())
{
SgSymbol* sym = extractSymbol(varExp);
if (sym)
{
varRefs.push_back(new SgVarRefExp(sym));
}
}
}
}
}
for (SgSymbol* sym : symbols)
varRefs.push_back(new SgVarRefExp(sym));
SgExpression* varList = makeExprList(varRefs, false);
if (!commonList)
{
current = funcStart->lexNext();
while (current != funcEnd && current)
{
if (current->variant() == COMM_STAT)
{
commonStat = current;
break;
}
current = current->lexNext();
}
SgSymbol* commonSymbol = new SgSymbol(COMMON_NAME, commonBlockName.c_str());
commonList = new SgExpression(COMM_LIST, NULL, NULL, commonSymbol);
commonList = new SgExpression(COMM_LIST, varList, NULL, commonSymbol);
if (commonStat)
{
SgExpression* lastCommList = commonStat->expr(0);
if (lastCommList)
{
while (lastCommList->rhs())
lastCommList = lastCommList->rhs();
lastCommList->setRhs(commonList);
}
else
{
commonStat->setExpression(0, commonList);
}
}
else
{
commonStat = new SgStatement(COMM_STAT);
commonStat->setFileName(declPlace->fileName());
commonStat->setFileId(declPlace->getFileId());
commonStat->setProject(declPlace->getProject());
commonStat->setlineNumber(getNextNegativeLineNumber());
commonStat->setExpression(0, commonList);
declPlace->insertStmtBefore(*commonStat, *declPlace->controlParent());
}
commonStat = new SgStatement(COMM_STAT);
commonStat->setFileName(funcStart->fileName());
commonStat->setFileId(funcStart->getFileId());
commonStat->setProject(funcStart->getProject());
commonStat->setlineNumber(getNextNegativeLineNumber());
commonStat->setExpression(0, commonList);
SgStatement* lastDecl = FindLastDeclStatement(funcStart);
lastDecl->insertStmtAfter(*commonStat, *funcStart);
}
varRefs.push_back(new SgVarRefExp(varSymbol));
if (varRefs.size() > 0)
else
{
std::reverse(varRefs.begin(), varRefs.end());
SgExpression* varList = makeExprList(varRefs, false);
commonList->setLhs(varList);
}
return new SgExpression(VAR_REF, NULL, NULL, varSymbol, type->copyPtr());
SgStatement* insertAfter = commonStat;
for (SgSymbol* sym : symbols)
{
SgStatement* declStmt = sym->makeVarDeclStmt();
if (!declStmt)
continue;
if (SgVarDeclStmt* vds = isSgVarDeclStmt(declStmt))
vds->setVariant(VAR_DECL_90);
declStmt->setFileName(funcStart->fileName());
declStmt->setFileId(funcStart->getFileId());
declStmt->setProject(funcStart->getProject());
declStmt->setlineNumber(getNextNegativeLineNumber());
insertAfter->insertStmtAfter(*declStmt, *funcStart);
insertAfter = declStmt;
}
}
static void TransformRightPart(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
@@ -227,6 +210,8 @@ static void TransformLeftPart(SgStatement* st, SgExpression* exp, unordered_map<
{
if (exp->symbol()->type()->variant() == T_STRING)
return;
if (changed.find(st) != changed.end())
return;
string expUnparsed = exp->unparse();
if (arrayToVariable.find(expUnparsed) == arrayToVariable.end() && exp->symbol()->type()->baseType())
{
@@ -234,12 +219,62 @@ static void TransformLeftPart(SgStatement* st, SgExpression* exp, unordered_map<
}
SgStatement* newStatement = new SgStatement(ASSIGN_STAT, NULL, NULL, arrayToVariable[expUnparsed]->copyPtr(), st->expr(1)->copyPtr(), NULL);
newStatement->setFileId(st->getFileId());
newStatement->setFileId(st->getFileId());
newStatement->setProject(st->getProject());
newStatement->setlineNumber(getNextNegativeLineNumber());
newStatement->setLocalLineNumber(st->lineNumber());
st->insertStmtBefore(*newStatement, *st->controlParent());
changed.insert(st);
}
static void TransformBorder(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
SgStatement* firstStatement = declPlace->lexPrev();
st = st->lexPrev();
string array = exp->unparse();
arrayToVariable[array] = CreateVar(variableNumber, exp->symbol()->type()->baseType());
while (st != firstStatement)
{
if (st->variant() == ASSIGN_STAT && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end())
{
if (st->expr(1))
{
TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);
}
if (st->expr(0) && st->expr(0)->variant() == ARRAY_REF && CheckConstIndexes(st->expr(0)->lhs()) && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end())
{
TransformLeftPart(st, st->expr(0), arrayToVariable, variableNumber);
}
}
st = st->lexPrev();
}
}
static void CheckVariable(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
SgStatement* firstStatement = declPlace->lexPrev();
st = st->lexPrev();
string varName = exp->unparse();
while (st != firstStatement)
{
if (st->variant() == ASSIGN_STAT && st->expr(0)->symbol() == exp->symbol())
{
TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);
}
if (st->variant() == ASSIGN_STAT && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end())
{
if (st->expr(1))
{
TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);
}
if (st->expr(0) && st->expr(0)->variant() == ARRAY_REF && CheckConstIndexes(st->expr(0)->lhs()) && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end())
{
TransformLeftPart(st, st->expr(0), arrayToVariable, variableNumber);
}
}
st = st->lexPrev();
}
}
void ArrayConstantPropagation(SgProject& project)
@@ -262,40 +297,37 @@ void ArrayConstantPropagation(SgProject& project)
for (; st != lastNode; st = st->lexNext())
{
if (st->variant() == ASSIGN_STAT)
{
if (st->expr(1))
{
TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);
}
if (st->expr(0) && st->expr(0)->variant() == ARRAY_REF && CheckConstIndexes(st->expr(0)->lhs()))
{
TransformLeftPart(st, st->expr(0), arrayToVariable, variableNumber);
}
}
else if (st->variant() == FOR_NODE)
if (st->variant() == FOR_NODE)
{
SgExpression* lowerBound = st->expr(0)->lhs();
SgExpression* upperBound = st->expr(0)->rhs();
string lowerBoundUnparsed = lowerBound->unparse(), upperBoundUnparsed = upperBound->unparse();
if (upperBound->variant() == ARRAY_REF && upperBound->symbol()->type()->baseType() && CheckConstIndexes(upperBound->lhs()))
{
if (arrayToVariable.find(upperBoundUnparsed) == arrayToVariable.end())
{
arrayToVariable[upperBoundUnparsed] = CreateVar(variableNumber, upperBound->symbol()->type()->baseType());
}
TransformBorder(st, upperBound, arrayToVariable, variableNumber);
st->expr(0)->setRhs(arrayToVariable[upperBoundUnparsed]->copyPtr());
}
else if (upperBound->variant() == VAR_REF)
CheckVariable(st, upperBound, arrayToVariable, variableNumber);
if (lowerBound->variant() == ARRAY_REF && lowerBound->symbol()->type()->baseType() && CheckConstIndexes(lowerBound->lhs()))
{
if (arrayToVariable.find(lowerBoundUnparsed) == arrayToVariable.end())
{
arrayToVariable[lowerBoundUnparsed] = CreateVar(variableNumber, lowerBound->symbol()->type()->baseType());
}
TransformBorder(st, lowerBound, arrayToVariable, variableNumber);
st->expr(0)->setLhs(arrayToVariable[lowerBoundUnparsed]->copyPtr());
}
else if (lowerBound->variant() == VAR_REF)
CheckVariable(st, lowerBound, arrayToVariable, variableNumber);
}
}
}
unordered_set<SgStatement*> funcStarts;
for (SgStatement* st : positionsToAdd)
{
SgStatement* scope = st->controlParent();
if (scope)
funcStarts.insert(scope);
}
for (const auto& st: funcStarts)
InsertCommonAndDeclsForFunction(st, variablesToAdd);
}
}

View File

@@ -33,10 +33,9 @@ static void RemoveEmptyPoints(ArrayAccessingIndexes& container)
points.push_back(arrayPoint);
}
if (points.size() < accessingSet.GetElements().size() && !points.empty())
if (!points.empty())
resultContainer[arrayName] = points;
if (points.empty())
else
toRemove.insert(arrayName);
}
@@ -281,7 +280,6 @@ static void SolveDataFlow(Region* DFG)
static bool getArrayDeclaredDimensions(SgArrayRefExp* arrayRef, vector<uint64_t>& declaredDims)
{
declaredDims.clear();
if (!arrayRef || !arrayRef->symbol() || !isSgArrayType(arrayRef->symbol()->type()))
return false;
SgArrayType* arrayType = (SgArrayType*)arrayRef->symbol()->type();
@@ -290,17 +288,22 @@ static bool getArrayDeclaredDimensions(SgArrayRefExp* arrayRef, vector<uint64_t>
{
SgExpression* sizeExpr = arrayType->sizeInDim(i);
SgConstantSymb* constValSymb = isSgConstantSymb(sizeExpr->symbol());
string strDimLength;
SgSubscriptExp* subscriptExpr = isSgSubscriptExp(sizeExpr);
uint64_t dimLength;
if (sizeExpr && sizeExpr->variant() == INT_VAL)
strDimLength = sizeExpr->unparse();
dimLength = stol(sizeExpr->unparse());
else if (constValSymb)
strDimLength = constValSymb->constantValue()->unparse();
dimLength = stol(constValSymb->constantValue()->unparse());
else if (subscriptExpr)
{
dimLength = stol(subscriptExpr->rhs()->unparse()) - stol(subscriptExpr->lhs()->unparse());
}
else
return false;
if (strDimLength == "0")
if (dimLength == 0)
return false;
declaredDims.push_back((uint64_t)stoi(strDimLength));
declaredDims.push_back(dimLength);
}
return true;
}
@@ -313,7 +316,8 @@ static bool CheckDimensionLength(const AccessingSet& array)
SgArrayRefExp* arrayRef = array.GetElements()[0][0].array;
if (!arrayRef)
return false;
vector<uint64_t> declaredDims(dimCount);
vector<uint64_t> declaredDims;
declaredDims.reserve(dimCount);
if (!getArrayDeclaredDimensions(arrayRef, declaredDims))
return false;
vector<ArrayDimension> testArray(dimCount);

View File

@@ -204,7 +204,7 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
current_dim = { start, step, iters, ref };
}
if (current_dim.start != 0 && current_dim.step != 0 && current_dim.tripCount != 0)
if (current_dim.step != 0 && current_dim.tripCount != 0)
{
accessPoint[n - index_vars.size()] = current_dim;
fillCount++;