Files
SAPFOR/src/ArrayConstantPropagation/propagation.cpp

333 lines
11 KiB
C++
Raw Normal View History

2025-12-19 01:55:23 +03:00
#include "propagation.h"
#include "../Utils/SgUtils.h"
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using namespace std;
2026-03-12 04:25:45 +03:00
// Statement that follows the header of the function currently being processed.
// Assigned in ArrayConstantPropagation; read by CreateVar, TransformBorder and CheckVariable.
SgStatement* declPlace = NULL;
// Assignment statements for which TransformLeftPart has already inserted a scalar copy.
unordered_set<SgStatement*> changed;
// Symbols of every temporary scalar produced by CreateVar.
unordered_set<SgSymbol*> variablesToAdd;
// Values of declPlace recorded by CreateVar; ArrayConstantPropagation uses their
// control parents as the functions that receive the COMMON block and declarations.
unordered_set<SgStatement*> positionsToAdd;
2025-12-19 01:55:23 +03:00
/**
 * Tells whether every subscript of an index list is an integer literal.
 *
 * The list is walked as a chain in which lhs() holds one subscript and rhs()
 * holds the remainder of the list.
 *
 * @param exp head of the subscript list; may be NULL.
 * @return true when the list is absent or all subscripts have variant INT_VAL;
 *         false as soon as one subscript is missing or is not INT_VAL.
 */
static bool CheckConstIndexes(SgExpression* exp)
{
    // Visit every node of the chain, including the last one: the previous
    // do/while form stopped before testing the final subscript.
    for (SgExpression* node = exp; node; node = node->rhs())
    {
        SgExpression* index = node->lhs();
        if (!index || index->variant() != INT_VAL)
            return false;
    }
    return true;
}
/**
 * Creates a fresh scalar symbol named "__tmp_prop_var<N>__" and returns a
 * VAR_REF expression that refers to it.
 *
 * Side effects: increments variableNumber, records the new symbol in
 * variablesToAdd and the current declPlace in positionsToAdd.
 *
 * @param variableNumber running counter used for the numeric part of the name.
 * @param type           type given to the symbol; a copy of it is attached to the
 *                       returned expression.
 * @return newly allocated VAR_REF expression for the new symbol.
 */
static SgExpression* CreateVar(int& variableNumber, SgType* type)
{
    string name = "__tmp_prop_var";
    name += std::to_string(variableNumber);
    name += "__";
    ++variableNumber;

    // The symbol is created in the scope of the control parent of declPlace.
    SgStatement* scope = declPlace->controlParent();
    SgSymbol* tmpSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), *type, *scope);

    variablesToAdd.insert(tmpSymbol);
    positionsToAdd.insert(declPlace);

    return new SgExpression(VAR_REF, NULL, NULL, tmpSymbol, type->copyPtr());
}
/**
 * Scans the statements of a function from its header forward and returns the
 * last statement whose variant is one of the recognised declaration kinds.
 *
 * The scan stops at the function's last node or at the first statement for
 * which isSgExecutableStatement() is true. An INTERFACE_STMT is skipped by
 * jumping to its own last node.
 *
 * @param funcStart header statement of the function.
 * @return the last declaration statement seen, or funcStart when none was found.
 */
static SgStatement* FindLastDeclStatement(SgStatement* funcStart)
{
    SgStatement* const stop = funcStart->lastNodeOfStmt();
    SgStatement* result = funcStart;

    const set<int> declarationKinds = { VAR_DECL, VAR_DECL_90, ALLOCATABLE_STMT, DIM_STAT,
                                        EXTERN_STAT, COMM_STAT, HPF_TEMPLATE_STAT, DVM_VAR_DECL, STRUCT_DECL };

    for (SgStatement* node = funcStart->lexNext(); node && node != stop; node = node->lexNext())
    {
        // Jump over the whole interface body.
        if (node->variant() == INTERFACE_STMT)
            node = node->lastNodeOfStmt();

        if (declarationKinds.count(node->variant()) != 0)
        {
            result = node;
            continue;
        }
        if (isSgExecutableStatement(node))
            break;
    }
    return result;
}
/**
 * Makes the given symbols members of the COMMON block "__propagation_common__"
 * inside one function and inserts a declaration statement for each of them.
 *
 * If the function has no such COMMON block, one is created and placed after the
 * statement returned by FindLastDeclStatement(). If the block already exists, its
 * variable list is replaced by the list built from `symbols`.
 *
 * Declarations are inserted directly after the COMMON statement. Symbols that
 * were already listed in a pre-existing block are not declared again: this
 * function declares every symbol it lists, so a second declaration would be a
 * duplicate.
 * NOTE(review): this assumes the block is only ever created by this function —
 * confirm no other pass emits a COMMON block with the same name.
 *
 * @param funcStart header statement of the function to modify.
 * @param symbols   symbols to list in the COMMON block; nothing is done when empty.
 */
static void InsertCommonAndDeclsForFunction(SgStatement* funcStart, const unordered_set<SgSymbol*>& symbols)
{
    if (symbols.empty())
        return;

    const string commonBlockName = "__propagation_common__";
    SgStatement* funcEnd = funcStart->lastNodeOfStmt();

    // Look for an existing COMMON statement that carries our block name.
    SgStatement* commonStat = NULL;
    SgExpression* commonList = NULL;
    for (SgStatement* cur = funcStart->lexNext(); cur && cur != funcEnd; cur = cur->lexNext())
    {
        if (cur->variant() != COMM_STAT)
            continue;
        for (SgExpression* exp = cur->expr(0); exp; exp = exp->rhs())
        {
            if (exp->variant() != COMM_LIST)
                continue;
            const char* id = exp->symbol() ? exp->symbol()->identifier() : NULL;
            const string existingName = id ? string(id) : string("spf_unnamed");
            if (existingName == commonBlockName)
            {
                commonStat = cur;
                commonList = exp;
                break;
            }
        }
        if (commonStat)
            break;
    }

    // Remember which symbols the existing block already lists, before the list is replaced.
    unordered_set<SgSymbol*> alreadyDeclared;
    if (commonList)
    {
        for (SgExpression* item = commonList->lhs(); item; item = item->rhs())
        {
            const bool isListNode = item->variant() == EXPR_LIST;
            SgExpression* ref = isListNode ? item->lhs() : item;
            if (ref && ref->symbol())
                alreadyDeclared.insert(ref->symbol());
            if (!isListNode)
                break;
        }
    }

    vector<SgExpression*> varRefs;
    for (SgSymbol* sym : symbols)
        varRefs.push_back(new SgVarRefExp(sym));
    SgExpression* varList = makeExprList(varRefs, false);

    if (!commonList)
    {
        SgSymbol* commonSymbol = new SgSymbol(COMMON_NAME, commonBlockName.c_str());
        commonList = new SgExpression(COMM_LIST, varList, NULL, commonSymbol);

        commonStat = new SgStatement(COMM_STAT);
        commonStat->setFileName(funcStart->fileName());
        commonStat->setFileId(funcStart->getFileId());
        commonStat->setProject(funcStart->getProject());
        commonStat->setlineNumber(getNextNegativeLineNumber());
        commonStat->setExpression(0, commonList);

        SgStatement* lastDecl = FindLastDeclStatement(funcStart);
        lastDecl->insertStmtAfter(*commonStat, *funcStart);
    }
    else
    {
        commonList->setLhs(varList);
    }

    // Chain the declarations one after another, starting right after the COMMON statement.
    SgStatement* insertAfter = commonStat;
    for (SgSymbol* sym : symbols)
    {
        if (alreadyDeclared.count(sym) != 0)
            continue;
        SgStatement* declStmt = sym->makeVarDeclStmt();
        if (!declStmt)
            continue;
        if (SgVarDeclStmt* vds = isSgVarDeclStmt(declStmt))
            vds->setVariant(VAR_DECL_90);
        declStmt->setFileName(funcStart->fileName());
        declStmt->setFileId(funcStart->getFileId());
        declStmt->setProject(funcStart->getProject());
        declStmt->setlineNumber(getNextNegativeLineNumber());
        insertAfter->insertStmtAfter(*declStmt, *funcStart);
        insertAfter = declStmt;
    }
}
/**
 * Replaces constant-index array element references inside a right-hand side
 * with references to their scalar stand-ins, creating stand-ins on demand.
 *
 * Two cases are handled:
 *  - `exp` itself is a constant-index array element with a known base type:
 *    expression 1 of `st` is replaced by the stand-in;
 *  - otherwise both operands of `exp` are examined; a qualifying operand is
 *    replaced in place and any other operand is processed recursively.
 *
 * @param st              statement that owns the expression.
 * @param exp             expression (sub)tree to process; NULL is ignored.
 * @param arrayToVariable map from the unparsed text of an array element to its stand-in.
 * @param variableNumber  counter forwarded to CreateVar.
 */
static void TransformRightPart(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (!exp)
        return;

    if (exp->variant() == ARRAY_REF && CheckConstIndexes(exp->lhs()))
    {
        // Trace output kept from the original implementation.
        cout << st->unparse() << endl;

        // Without an element type no stand-in can be created, and the statement
        // must not be rewritten with an unrelated or null entry.
        SgType* elemType = exp->symbol()->type()->baseType();
        if (!elemType)
            return;

        // Key the map by the text of this element; previously the key was
        // left empty here, so all elements collided on "".
        const string key = exp->unparse();
        auto found = arrayToVariable.find(key);
        if (found == arrayToVariable.end())
            found = arrayToVariable.emplace(key, CreateVar(variableNumber, elemType)).first;
        if (found->second)
            st->setExpression(1, found->second->copyPtr());
        return;
    }

    // Operands are captured before any replacement so both are inspected as they were.
    SgExpression* operands[2] = { exp->lhs(), exp->rhs() };
    for (int i = 0; i < 2; i++)
    {
        SgExpression* operand = operands[i];
        if (!operand)
            continue;

        SgType* elemType = NULL;
        if (operand->variant() == ARRAY_REF)
            elemType = operand->symbol()->type()->baseType();

        if (elemType && CheckConstIndexes(operand->lhs()))
        {
            const string key = operand->unparse();
            auto found = arrayToVariable.find(key);
            if (found == arrayToVariable.end())
                found = arrayToVariable.emplace(key, CreateVar(variableNumber, elemType)).first;
            if (!found->second)
                continue;

            SgExpression* replacement = found->second->copyPtr();
            if (replacement)
            {
                if (i == 0)
                    exp->setLhs(replacement);
                else
                    exp->setRhs(replacement);
            }
        }
        else
        {
            TransformRightPart(st, operand, arrayToVariable, variableNumber);
        }
    }
}
/**
 * For an assignment whose target is the array element `exp`, inserts in front
 * of it a new assignment that stores the same right-hand side into the
 * element's scalar stand-in.
 *
 * Nothing is done when the target symbol has type variant T_STRING, when the
 * statement was already handled (it is in `changed`), when the statement has
 * no right-hand side, or when no stand-in exists and none can be created
 * because the symbol has no base type.
 *
 * @param st              assignment statement being mirrored.
 * @param exp             left-hand side of `st`.
 * @param arrayToVariable map from the unparsed text of an array element to its stand-in.
 * @param variableNumber  counter forwarded to CreateVar.
 */
static void TransformLeftPart(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (exp->symbol()->type()->variant() == T_STRING)
        return;
    if (changed.find(st) != changed.end())
        return;

    SgExpression* rightPart = st->expr(1);
    if (!rightPart)
        return;

    const string expUnparsed = exp->unparse();
    auto found = arrayToVariable.find(expUnparsed);
    if (found == arrayToVariable.end())
    {
        // Do not let operator[] insert a null entry that would be dereferenced below.
        SgType* elemType = exp->symbol()->type()->baseType();
        if (!elemType)
            return;
        found = arrayToVariable.emplace(expUnparsed, CreateVar(variableNumber, elemType)).first;
    }
    if (!found->second)
        return;

    SgStatement* newStatement = new SgStatement(ASSIGN_STAT, NULL, NULL, found->second->copyPtr(), rightPart->copyPtr(), NULL);
    newStatement->setFileId(st->getFileId());
    newStatement->setProject(st->getProject());
    newStatement->setlineNumber(getNextNegativeLineNumber());
    newStatement->setLocalLineNumber(st->lineNumber());
    st->insertStmtBefore(*newStatement, *st->controlParent());

    // Remember the statement so that a later pass does not mirror it again.
    changed.insert(st);
}
/**
 * Prepares the scalar stand-in for an array element used as a loop bound and
 * rewrites the assignments that precede the loop.
 *
 * A stand-in for `exp` is created only if the map does not have one yet.
 * Re-creating it would orphan the existing one: the assignment that defines it
 * is already in `changed` and is never mirrored again, so the new scalar would
 * stay unassigned.
 *
 * Statements are then visited backwards from the one before `st` down to (not
 * including) the statement before declPlace. For every assignment whose target
 * text is a key of the map, the right-hand side is processed by
 * TransformRightPart and, when the target is a constant-index array element,
 * TransformLeftPart is applied.
 *
 * @param st              loop statement whose bound is being replaced.
 * @param exp             the array element used as the bound.
 * @param arrayToVariable map from the unparsed text of an array element to its stand-in.
 * @param variableNumber  counter forwarded to CreateVar.
 */
static void TransformBorder(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    SgStatement* firstStatement = declPlace->lexPrev();
    st = st->lexPrev();

    const string array = exp->unparse();
    if (arrayToVariable.find(array) == arrayToVariable.end())
        arrayToVariable[array] = CreateVar(variableNumber, exp->symbol()->type()->baseType());

    // The null test keeps the walk from running off the start of the statement list.
    while (st && st != firstStatement)
    {
        if (st->variant() == ASSIGN_STAT)
        {
            SgExpression* target = st->expr(0);
            if (target && arrayToVariable.find(target->unparse()) != arrayToVariable.end())
            {
                if (st->expr(1))
                    TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);

                // Entries are never removed from the map, so the target is still a key here.
                target = st->expr(0);
                if (target && target->variant() == ARRAY_REF && CheckConstIndexes(target->lhs()))
                    TransformLeftPart(st, target, arrayToVariable, variableNumber);
            }
        }
        st = st->lexPrev();
    }
}
/**
 * Handles a loop bound that is a plain variable reference.
 *
 * Statements are visited backwards from the one before `st` down to (not
 * including) the statement before declPlace. For every assignment:
 *  - if its target has the same symbol as `exp`, its right-hand side is
 *    processed by TransformRightPart;
 *  - if its target text is a key of the map, its right-hand side is processed
 *    by TransformRightPart and, when the target is a constant-index array
 *    element, TransformLeftPart is applied.
 *
 * @param st              loop statement whose bound is being examined.
 * @param exp             the variable reference used as the bound.
 * @param arrayToVariable map from the unparsed text of an array element to its stand-in.
 * @param variableNumber  counter forwarded to CreateVar.
 */
static void CheckVariable(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    SgStatement* firstStatement = declPlace->lexPrev();
    st = st->lexPrev();

    // The null test keeps the walk from running off the start of the statement list.
    while (st && st != firstStatement)
    {
        if (st->variant() == ASSIGN_STAT && st->expr(0))
        {
            if (st->expr(0)->symbol() == exp->symbol())
                TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);

            if (arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end())
            {
                if (st->expr(1))
                    TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);

                // Entries are never removed from the map, so the target is still a key here.
                SgExpression* target = st->expr(0);
                if (target && target->variant() == ARRAY_REF && CheckConstIndexes(target->lhs()))
                    TransformLeftPart(st, target, arrayToVariable, variableNumber);
            }
        }
        st = st->lexPrev();
    }
}
/**
 * Entry point of the pass: replaces constant-index array elements used as DO
 * loop bounds with scalar temporaries.
 *
 * For every function of every file, each FOR_NODE statement is inspected. A
 * bound that is a constant-index array element with a base type is handed to
 * TransformBorder and then replaced by its stand-in; a bound that is a plain
 * VAR_REF is handed to CheckVariable. The upper bound is handled before the
 * lower bound.
 *
 * After the functions of a file have been processed, every function recorded
 * through positionsToAdd receives the COMMON block and declarations for all
 * symbols in variablesToAdd.
 * NOTE(review): positionsToAdd, variablesToAdd and arrayToVariable are not
 * reset between files, so functions of earlier files are visited again on
 * later iterations — confirm this is intended.
 *
 * @param project project whose files are transformed in place.
 */
void ArrayConstantPropagation(SgProject& project)
{
    unordered_map<string, SgExpression*> arrayToVariable;
    int variableNumber = 0;

    // True for an array element with a base type and only literal subscripts.
    auto isConstElement = [](SgExpression* e) -> bool
    {
        return e->variant() == ARRAY_REF && e->symbol()->type()->baseType() && CheckConstIndexes(e->lhs());
    };

    for (int fileIdx = 0; fileIdx < project.numberOfFiles(); fileIdx++)
    {
        SgFile* file = &(project.file(fileIdx));
        if (!file)
            continue;

        const int funcNum = file->numberOfFunctions();
        for (int funcIdx = 0; funcIdx < funcNum; ++funcIdx)
        {
            SgStatement* st = file->functions(funcIdx);
            declPlace = st->lexNext();
            SgStatement* lastNode = st->lastNodeOfStmt();

            for (; st != lastNode; st = st->lexNext())
            {
                if (st->variant() != FOR_NODE)
                    continue;

                SgExpression* lowerBound = st->expr(0)->lhs();
                SgExpression* upperBound = st->expr(0)->rhs();
                // Keys are taken before any rewriting of the bounds.
                string lowerBoundUnparsed = lowerBound->unparse();
                string upperBoundUnparsed = upperBound->unparse();

                if (isConstElement(upperBound))
                {
                    TransformBorder(st, upperBound, arrayToVariable, variableNumber);
                    st->expr(0)->setRhs(arrayToVariable[upperBoundUnparsed]->copyPtr());
                }
                else if (upperBound->variant() == VAR_REF)
                {
                    CheckVariable(st, upperBound, arrayToVariable, variableNumber);
                }

                if (isConstElement(lowerBound))
                {
                    TransformBorder(st, lowerBound, arrayToVariable, variableNumber);
                    st->expr(0)->setLhs(arrayToVariable[lowerBoundUnparsed]->copyPtr());
                }
                else if (lowerBound->variant() == VAR_REF)
                {
                    CheckVariable(st, lowerBound, arrayToVariable, variableNumber);
                }
            }
        }

        // Map every recorded position to its enclosing function, without duplicates.
        unordered_set<SgStatement*> funcStarts;
        for (SgStatement* position : positionsToAdd)
        {
            if (SgStatement* scope = position->controlParent())
                funcStarts.insert(scope);
        }
        for (SgStatement* funcStart : funcStarts)
            InsertCommonAndDeclsForFunction(funcStart, variablesToAdd);
    }
}