#include "propagation.h"
#include "../Utils/SgUtils.h"

#include <cstring>
#include <functional>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace std;

namespace
{
    // Hashes a symbol by its identifier text; paired with MyEq so that two
    // SgSymbol objects carrying the same name are one element of a set.
    struct MyHash
    {
        size_t operator()(const SgSymbol* s) const
        {
            return std::hash<std::string>{}(std::string(s->identifier()));
        }
    };

    // Compares symbols by identifier text.
    struct MyEq
    {
        bool operator()(const SgSymbol* a, const SgSymbol* b) const
        {
            return std::strcmp(a->identifier(), b->identifier()) == 0;
        }
    };
}

// Statement right after the header of the function currently being scanned.
// CreateVar scopes new temporaries to its control parent, and this is the
// value recorded in positionsToAdd.
SgStatement* declPlace = NULL;
// Assignments in front of which a temporary assignment was already inserted.
unordered_set<SgStatement*> changed;
// Temporaries created by CreateVar (deduplicated by name via MyHash/MyEq).
unordered_set<SgSymbol*, MyHash, MyEq> variablesToAdd;
// declPlace values of the functions that must declare the temporaries.
unordered_set<SgStatement*> positionsToAdd;
// Statements inserted by this pass. NOTE(review): presumably consumed elsewhere
// to undo the transformation - confirm against users of propagation.h.
unordered_set<SgStatement*> statementsToRemove;
// Per file name: (modified statement, copy of it taken just before the change).
unordered_map<string, vector<pair<SgStatement*, SgStatement*>>> expToChange;

// Unparsed array element text (e.g. "a(1)") -> reference to its temporary.
using ArrayToVariableMap = unordered_map<string, SgExpression*>;

// Returns true when every subscript in `exp` is an integer literal.
// `exp` is a Sage++ expression list: each node keeps one subscript in lhs()
// and the rest of the list in rhs(). A missing list (NULL) is accepted.
static bool CheckConstIndexes(SgExpression* exp)
{
    // Visit every node, including the last one: stopping as soon as rhs()
    // becomes NULL would leave the final subscript unchecked (A(1, i)).
    for (SgExpression* node = exp; node; node = node->rhs())
    {
        SgExpression* index = node->lhs();
        if (!index || index->variant() != INT_VAL)
            return false;
    }
    return true;
}

// Creates the next temporary "__tmp_prop_var<N>__" of type `type`, scoped to
// the control parent of declPlace, registers it for declaration and returns a
// VAR_REF expression to it.
static SgExpression* CreateVar(int& variableNumber, SgType* type)
{
    const string name = "__tmp_prop_var" + std::to_string(variableNumber) + "__";
    variableNumber++;

    SgStatement* funcStart = declPlace->controlParent();
    SgSymbol* varSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), *type, *funcStart);

    variablesToAdd.insert(varSymbol);
    positionsToAdd.insert(declPlace);

    return new SgExpression(VAR_REF, NULL, NULL, varSymbol, type->copyPtr());
}

// Returns the temporary mapped to `key`, creating one of `baseType` when the
// key is new. Returns NULL (and maps nothing) when the key is new and there is
// no element type to build the temporary from.
static SgExpression* GetOrCreateVar(const string& key, SgType* baseType,
                                    ArrayToVariableMap& arrayToVariable, int& variableNumber)
{
    ArrayToVariableMap::iterator it = arrayToVariable.find(key);
    if (it != arrayToVariable.end())
        return it->second;
    if (!baseType)
        return NULL;

    SgExpression* var = CreateVar(variableNumber, baseType);
    arrayToVariable[key] = var;
    return var;
}

// Builds a fresh reference to the temporary behind `var`, through a symbol
// scoped to the control parent of `st`.
static SgExpression* MakeLocalRef(SgExpression* var, SgStatement* st)
{
    SgSymbol* builder = var->symbol();
    SgSymbol* sym = new SgSymbol(builder->variant(), builder->identifier(), builder->type(), st->controlParent());
    return new SgVarRefExp(sym);
}

// Returns the last declaration-like statement of the function starting at
// `funcStart`, or `funcStart` itself when there is none. The scan jumps over
// interface blocks and stops at the first executable statement.
static SgStatement* FindLastDeclStatement(SgStatement* funcStart)
{
    static const set<int> declVariants = { VAR_DECL, VAR_DECL_90, ALLOCATABLE_STMT, DIM_STAT,
                                           EXTERN_STAT, COMM_STAT, HPF_TEMPLATE_STAT, DVM_VAR_DECL,
                                           STRUCT_DECL };

    SgStatement* endSt = funcStart->lastNodeOfStmt();
    SgStatement* lastDecl = funcStart;

    for (SgStatement* cur = funcStart->lexNext(); cur && cur != endSt; cur = cur->lexNext())
    {
        // Jump to the last node of an interface block; that node is then
        // classified like any other statement below.
        if (cur->variant() == INTERFACE_STMT)
            cur = cur->lastNodeOfStmt();

        if (declVariants.count(cur->variant()))
            lastDecl = cur;
        else if (isSgExecutableStatement(cur))
            break;
    }
    return lastDecl;
}

// True for symbols that get a declaration and a slot in the propagation
// common block.
static bool IsPropagationVariable(SgSymbol* sym, const string& commonBlockName)
{
    return sym && sym->variant() == VARIABLE_NAME && string(sym->identifier()) != commonBlockName;
}

// Gives a statement created by this pass the file/project of `from` and a new
// negative line number.
static void PlaceLike(SgStatement* stmt, SgStatement* from)
{
    stmt->setFileName(from->fileName());
    stmt->setFileId(from->getFileId());
    stmt->setProject(from->getProject());
    stmt->setlineNumber(getNextNegativeLineNumber());
}

// Returns the COMM_LIST node named `commonBlockName` from a COMMON statement
// of the function starting at `funcStart`, or NULL when there is none.
static SgExpression* FindCommonList(SgStatement* funcStart, const string& commonBlockName)
{
    SgStatement* funcEnd = funcStart->lastNodeOfStmt();
    for (SgStatement* cur = funcStart->lexNext(); cur && cur != funcEnd; cur = cur->lexNext())
    {
        if (cur->variant() != COMM_STAT)
            continue;

        for (SgExpression* exp = cur->expr(0); exp; exp = exp->rhs())
        {
            if (exp->variant() != COMM_LIST)
                continue;

            const char* id = exp->symbol() ? exp->symbol()->identifier() : NULL;
            const string existingName = id ? string(id) : string("spf_unnamed");
            if (existingName == commonBlockName)
                return exp;
        }
    }
    return NULL;
}

// Declares every temporary in `symbols` inside the function `funcStart` (one
// VAR_DECL_90 statement each, after the last declaration) and lists them in
// COMMON /__propagation_common__/, creating that statement after the last
// declaration or replacing the variable list of an existing one.
static void InsertCommonAndDeclsForFunction(SgStatement* funcStart, const unordered_set<SgSymbol*, MyHash, MyEq>& symbols)
{
    if (symbols.empty())
        return;

    const string commonBlockName = "__propagation_common__";
    SgExpression* commonList = FindCommonList(funcStart, commonBlockName);

    vector<SgExpression*> varRefs;
    for (SgSymbol* sym : symbols)
    {
        if (!IsPropagationVariable(sym, commonBlockName))
            continue;
        SgSymbol* symToAdd = new SgSymbol(VARIABLE_NAME, sym->identifier(), *sym->type(), *funcStart);
        varRefs.push_back(new SgVarRefExp(symToAdd));
    }
    // Nothing qualifies: do not emit declarations or an empty COMMON list.
    if (varRefs.empty())
        return;
    SgExpression* varList = makeExprList(varRefs, false);

    // Same filter as the COMMON list so declarations and common entries agree.
    SgStatement* insertAfter = FindLastDeclStatement(funcStart);
    for (SgSymbol* sym : symbols)
    {
        if (!IsPropagationVariable(sym, commonBlockName))
            continue;

        SgStatement* declStmt = sym->makeVarDeclStmt();
        if (!declStmt)
            continue;
        if (SgVarDeclStmt* vds = isSgVarDeclStmt(declStmt))
            vds->setVariant(VAR_DECL_90);

        PlaceLike(declStmt, funcStart);
        insertAfter->insertStmtAfter(*declStmt, *funcStart);
        insertAfter = declStmt;
        statementsToRemove.insert(declStmt);
    }

    if (commonList)
    {
        commonList->setLhs(varList);
        return;
    }

    SgSymbol* commonSymbol = new SgSymbol(COMMON_NAME, commonBlockName.c_str());
    commonList = new SgExpression(COMM_LIST, varList, NULL, commonSymbol);

    SgStatement* commonStat = new SgStatement(COMM_STAT);
    PlaceLike(commonStat, funcStart);
    commonStat->setExpression(0, commonList);

    // Re-scan so the COMMON lands after the declarations inserted above.
    SgStatement* lastDecl = FindLastDeclStatement(funcStart);
    lastDecl->insertStmtAfter(*commonStat, *funcStart);
    statementsToRemove.insert(commonStat);
}

// Replaces constant-subscript array references inside `exp` with references to
// their temporaries. `exp` is the right part of assignment `st` when `isRoot`
// is true, otherwise a subtree of it. A copy of `st` is recorded in
// expToChange before every modification.
static void TransformRightPart(SgStatement* st, SgExpression* exp, ArrayToVariableMap& arrayToVariable,
                               int& variableNumber, bool isRoot = true)
{
    if (!exp)
        return;

    // The whole right part is one element: swap the statement's expression 1.
    // This is only valid at the root; nested elements are replaced through
    // their parent in the loop below.
    if (isRoot && exp->variant() == ARRAY_REF && CheckConstIndexes(exp->lhs()))
    {
        SgExpression* var = GetOrCreateVar(exp->unparse(), exp->symbol()->type()->baseType(),
                                           arrayToVariable, variableNumber);
        if (!var)
            return; // not mapped and no element type to create a temporary from

        positionsToAdd.insert(declPlace);
        SgExpression* newVarExp = MakeLocalRef(var, st);
        expToChange[st->fileName()].push_back({ st, st->copyPtr() });
        st->setExpression(1, newVarExp);
        return;
    }

    // Capture both children before either of them is replaced.
    SgExpression* subnodes[2] = { exp->lhs(), exp->rhs() };
    for (int i = 0; i < 2; ++i)
    {
        SgExpression* sub = subnodes[i];
        SgType* baseType = (sub && sub->variant() == ARRAY_REF) ? sub->symbol()->type()->baseType() : NULL;

        if (baseType && CheckConstIndexes(sub->lhs()))
        {
            SgExpression* var = GetOrCreateVar(sub->unparse(), baseType, arrayToVariable, variableNumber);
            positionsToAdd.insert(declPlace);
            SgExpression* toAdd = MakeLocalRef(var, st);

            expToChange[st->fileName()].push_back({ st, st->copyPtr() });
            if (i == 0)
                exp->setLhs(toAdd);
            else
                exp->setRhs(toAdd);
        }
        else
            TransformRightPart(st, sub, arrayToVariable, variableNumber, false);
    }
}

// Inserts "tmp = <right part of st>" just before assignment `st`, whose left
// part `exp` is an array element, so the temporary mirrors that store.
// Character-typed symbols and already handled statements are skipped.
static void TransformLeftPart(SgStatement* st, SgExpression* exp, ArrayToVariableMap& arrayToVariable, int& variableNumber)
{
    if (exp->symbol()->type()->variant() == T_STRING)
        return;
    if (changed.find(st) != changed.end())
        return;

    SgExpression* var = GetOrCreateVar(exp->unparse(), exp->symbol()->type()->baseType(),
                                       arrayToVariable, variableNumber);
    if (!var)
        return;
    positionsToAdd.insert(declPlace);

    SgStatement* newStatement = new SgStatement(ASSIGN_STAT, NULL, NULL, var->copyPtr(), st->expr(1)->copyPtr(), NULL);
    newStatement->setFileId(st->getFileId());
    newStatement->setProject(st->getProject());
    newStatement->setlineNumber(getNextNegativeLineNumber());
    newStatement->setLocalLineNumber(st->lineNumber());

    st->insertStmtBefore(*newStatement, *st->controlParent());
    changed.insert(st);
    statementsToRemove.insert(newStatement);
}

// For an assignment whose left part is an already mapped array element:
// rewrites its right part and, for a constant-subscript left part, inserts the
// mirroring temporary assignment. Other statements are ignored.
static void PropagateIntoAssignment(SgStatement* st, ArrayToVariableMap& arrayToVariable, int& variableNumber)
{
    if (st->variant() != ASSIGN_STAT)
        return;

    SgExpression* left = st->expr(0);
    if (!left || arrayToVariable.find(left->unparse()) == arrayToVariable.end())
        return;

    if (st->expr(1))
    {
        TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);
        positionsToAdd.insert(declPlace);
    }

    // Entries are never erased, so `left` is still mapped here.
    if (left->variant() == ARRAY_REF && CheckConstIndexes(left->lhs()))
    {
        TransformLeftPart(st, left, arrayToVariable, variableNumber);
        positionsToAdd.insert(declPlace);
    }
}

// Loop bound `exp` is a constant-subscript array element: makes sure it has a
// temporary, then walks backwards from `st` to the function header and
// propagates every assignment to a mapped element.
static void TransformBorder(SgStatement* st, SgExpression* exp, ArrayToVariableMap& arrayToVariable, int& variableNumber)
{
    SgStatement* firstStatement = declPlace->lexPrev();
    GetOrCreateVar(exp->unparse(), exp->symbol()->type()->baseType(), arrayToVariable, variableNumber);

    for (SgStatement* cur = st->lexPrev(); cur && cur != firstStatement; cur = cur->lexPrev())
        PropagateIntoAssignment(cur, arrayToVariable, variableNumber);
}

// Loop bound `exp` is a scalar variable: walks backwards from `st` to the
// function header, rewriting the right parts of assignments to that variable
// and propagating assignments to mapped array elements.
static void CheckVariable(SgStatement* st, SgExpression* exp, ArrayToVariableMap& arrayToVariable, int& variableNumber)
{
    SgStatement* firstStatement = declPlace->lexPrev();

    for (SgStatement* cur = st->lexPrev(); cur && cur != firstStatement; cur = cur->lexPrev())
    {
        if (cur->variant() == ASSIGN_STAT && cur->expr(0)->symbol() == exp->symbol())
        {
            TransformRightPart(cur, cur->expr(1), arrayToVariable, variableNumber);
            positionsToAdd.insert(declPlace);
        }
        PropagateIntoAssignment(cur, arrayToVariable, variableNumber);
    }
}

// Handles one bound of DO loop `loop`. A constant-subscript array element is
// replaced by its temporary (upper bound when `isUpper`, else lower), recording
// the pre-change copy of the loop. A scalar bound is traced via CheckVariable.
static void ProcessLoopBound(SgStatement* loop, SgExpression* bound, bool isUpper,
                             ArrayToVariableMap& arrayToVariable, int& variableNumber)
{
    if (!bound)
        return;

    if (bound->variant() == ARRAY_REF && bound->symbol()->type()->baseType() && CheckConstIndexes(bound->lhs()))
    {
        const string boundUnparsed = bound->unparse();
        SgStatement* boundCopy = loop->copyPtr();

        TransformBorder(loop, bound, arrayToVariable, variableNumber);

        SgExpression* replacement = arrayToVariable[boundUnparsed]->copyPtr();
        if (isUpper)
            loop->expr(0)->setRhs(replacement);
        else
            loop->expr(0)->setLhs(replacement);

        expToChange[loop->fileName()].push_back({ loop, boundCopy });
        positionsToAdd.insert(declPlace);
    }
    else if (bound->variant() == VAR_REF)
        CheckVariable(loop, bound, arrayToVariable, variableNumber);
}

// Replaces DO-loop bounds that read constant-subscript array elements with
// temporaries kept in sync with every earlier store to those elements, then
// declares the temporaries (plus a shared COMMON block) in each function that
// needs them.
void ArrayConstantPropagation(SgProject& project)
{
    ArrayToVariableMap arrayToVariable;
    int variableNumber = 0;

    for (int i = 0; i < project.numberOfFiles(); i++)
    {
        SgFile* file = &(project.file(i));
        SgFile::switchToFile(file->filename());

        const int funcNum = file->numberOfFunctions();
        for (int f = 0; f < funcNum; ++f)
        {
            SgStatement* st = file->functions(f);
            declPlace = st->lexNext();
            SgStatement* lastNode = st->lastNodeOfStmt();

            for (; st != lastNode; st = st->lexNext())
            {
                if (st->variant() != FOR_NODE || !st->expr(0))
                    continue;

                // Read both bounds before either one is rewritten.
                SgExpression* lowerBound = st->expr(0)->lhs();
                SgExpression* upperBound = st->expr(0)->rhs();

                ProcessLoopBound(st, upperBound, true, arrayToVariable, variableNumber);
                ProcessLoopBound(st, lowerBound, false, arrayToVariable, variableNumber);
            }
        }
    }

    unordered_set<SgStatement*> funcStarts;
    for (SgStatement* position : positionsToAdd)
    {
        SgFile::switchToFile(position->fileName());
        if (SgStatement* scope = position->controlParent())
            funcStarts.insert(scope);
    }

    for (SgStatement* funcStart : funcStarts)
    {
        SgFile::switchToFile(funcStart->fileName());
        InsertCommonAndDeclsForFunction(funcStart, variablesToAdd);
    }
}