#include "propagation.h"

#include "../Utils/SgUtils.h"

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using namespace std;

// Array-element constant propagation into DO-loop bounds.
//
// For every DO loop whose bound is an array element with integer-literal
// subscripts (e.g. `do i = 1, a(2)`), the bound is replaced by a generated
// scalar temporary `tmp_prop_varN__`. Walking backwards from the loop, every
// earlier assignment to that element gets a companion `tmp = <rhs>` inserted
// before it. Constant-subscript array elements on those right-hand sides are
// themselves replaced by temporaries, which are propagated the same way.
// Temporaries are declared in the COMMON block `propagation_common__` of the
// function that uses them.
//
// NOTE(review): each function's COMMON lists only its own temporaries. Fortran
// associates same-named COMMON members by position, so temporaries of different
// program units share storage. Confirm this cannot clobber a value across a
// call.

// First statement of the function body currently being processed. COMMON
// statements for generated temporaries are inserted before it.
static SgStatement* declPlace = NULL;

// Assignment statements that already received a `tmp = <rhs>` companion, so a
// later backward walk does not insert a duplicate. Never cleared between runs.
static unordered_set<SgStatement*> changed;

// Name of the COMMON block that declares every generated temporary.
static const char* const kPropagationCommonName = "propagation_common__";

// Returns true when every subscript in the EXPR_LIST chain `exp` is an integer
// literal (INT_VAL). A missing subscript list (NULL) counts as constant.
static bool CheckConstIndexes(SgExpression* exp)
{
    // Visit every list node, including the last one.
    for (SgExpression* node = exp; node; node = node->rhs())
    {
        SgExpression* index = node->lhs();
        if (!index || index->variant() != INT_VAL)
            return false;
    }
    return true;
}

// True when `exp` is an array element reference with all-literal subscripts
// whose symbol has a known element (base) type, i.e. it can be replaced by a
// scalar temporary.
static bool IsConstElementRef(SgExpression* exp)
{
    if (!exp || exp->variant() != ARRAY_REF)
        return false;
    SgSymbol* sym = exp->symbol();
    if (!sym || !sym->type() || !sym->type()->baseType())
        return false;
    return CheckConstIndexes(exp->lhs());
}

// Returns the symbol named by a COMMON member node: either the node's own
// symbol or that of its left operand. Returns NULL if neither has one.
static SgSymbol* ExtractSymbol(SgExpression* exp)
{
    if (!exp)
        return NULL;
    if (exp->symbol())
        return exp->symbol();
    if (exp->lhs() && exp->lhs()->symbol())
        return exp->lhs()->symbol();
    return NULL;
}

// Searches the statements after `funcStart` (up to `funcEnd`) for a COMMON
// statement that declares the propagation block, and returns its COMM_LIST
// node. Blank (unnamed) blocks are compared as "spf_unnamed". Returns NULL if
// the block is not declared.
static SgExpression* FindNamedCommonList(SgStatement* funcStart, SgStatement* funcEnd)
{
    for (SgStatement* current = funcStart->lexNext(); current && current != funcEnd; current = current->lexNext())
    {
        if (current->variant() != COMM_STAT)
            continue;

        for (SgExpression* item = current->expr(0); item; item = item->rhs())
        {
            if (item->variant() != COMM_LIST)
                continue;

            const string existingName = item->symbol() ? string(item->symbol()->identifier()) : string("spf_unnamed");
            if (existingName == kPropagationCommonName)
                return item;
        }
    }
    return NULL;
}

// Builds fresh variable references for every member already listed in the
// COMM_LIST node `commonList`, in list order.
static vector<SgExpression*> CollectCommonVarRefs(SgExpression* commonList)
{
    vector<SgExpression*> refs;
    SgExpression* varList = commonList->lhs();
    if (!varList)
        return refs;

    // Members are either wrapped in EXPR_LIST nodes (member in lhs) or chained
    // directly through rhs.
    const bool wrapped = varList->variant() == EXPR_LIST;
    for (SgExpression* node = varList; node; node = node->rhs())
    {
        SgSymbol* sym = ExtractSymbol(wrapped ? node->lhs() : node);
        if (sym)
            refs.push_back(new SgVarRefExp(sym));
    }
    return refs;
}

// Creates an empty COMM_LIST for the propagation block and attaches it.
// It is appended to the first COMMON statement of the function if one exists.
// Otherwise a new COMMON statement is inserted before `declPlace`.
// Returns the new COMM_LIST node.
static SgExpression* AppendNewCommonList(SgStatement* funcStart, SgStatement* funcEnd)
{
    SgStatement* commonStat = NULL;
    for (SgStatement* current = funcStart->lexNext(); current && current != funcEnd; current = current->lexNext())
    {
        if (current->variant() == COMM_STAT)
        {
            commonStat = current;
            break;
        }
    }

    SgSymbol* commonSymbol = new SgSymbol(COMMON_NAME, kPropagationCommonName);
    SgExpression* commonList = new SgExpression(COMM_LIST, NULL, NULL, commonSymbol);

    if (commonStat)
    {
        SgExpression* tail = commonStat->expr(0);
        if (tail)
        {
            while (tail->rhs())
                tail = tail->rhs();
            tail->setRhs(commonList);
        }
        else
            commonStat->setExpression(0, commonList);
    }
    else
    {
        commonStat = new SgStatement(COMM_STAT);
        commonStat->setFileName(declPlace->fileName());
        commonStat->setFileId(declPlace->getFileId());
        commonStat->setProject(declPlace->getProject());
        commonStat->setlineNumber(getNextNegativeLineNumber());
        commonStat->setExpression(0, commonList);
        declPlace->insertStmtBefore(*commonStat, *declPlace->controlParent());
    }
    return commonList;
}

// Creates a new scalar temporary `tmp_prop_var<N>__` of type `type` in the
// function enclosing `declPlace`, and adds it to that function's propagation
// COMMON block (creating the block if needed). Increments `variableNumber`.
// Returns a VAR_REF expression for the new temporary.
static SgExpression* CreateVar(int& variableNumber, SgType* type)
{
    const string name = "tmp_prop_var" + to_string(variableNumber) + "__";
    variableNumber++;

    SgStatement* funcStart = declPlace->controlParent();
    SgSymbol* varSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), *type, *funcStart);
    SgStatement* funcEnd = funcStart->lastNodeOfStmt();

    vector<SgExpression*> varRefs;
    SgExpression* commonList = FindNamedCommonList(funcStart, funcEnd);
    if (commonList)
        varRefs = CollectCommonVarRefs(commonList);
    else
        commonList = AppendNewCommonList(funcStart, funcEnd);

    // Rebuild the member list: existing members plus the new temporary.
    // NOTE(review): the reversal presumably compensates for makeExprList
    // building its chain by prepending -- confirm against SgUtils.
    varRefs.push_back(new SgVarRefExp(varSymbol));
    reverse(varRefs.begin(), varRefs.end());
    commonList->setLhs(makeExprList(varRefs, false));

    return new SgExpression(VAR_REF, NULL, NULL, varSymbol, type->copyPtr());
}

// Returns the temporary mapped to the array element `arrayRef` (keyed by its
// unparsed text). Creates and records one if none exists yet. Returns NULL,
// without recording anything, if the element type is unknown.
static SgExpression* GetOrCreateVar(SgExpression* arrayRef, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    const string key = arrayRef->unparse();
    auto found = arrayToVariable.find(key);
    if (found != arrayToVariable.end())
        return found->second;

    SgType* elemType = arrayRef->symbol()->type()->baseType();
    if (!elemType)
        return NULL;

    SgExpression* var = CreateVar(variableNumber, elemType);
    arrayToVariable[key] = var;
    return var;
}

// Recursively replaces, below `exp`, every constant-subscript array element by
// a copy of its temporary. Other nodes, including array references with
// non-constant subscripts, are descended into so that nested elements are
// found too.
static void ReplaceNestedArrayRefs(SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (!exp)
        return;

    SgExpression* children[2] = { exp->lhs(), exp->rhs() };
    for (int side = 0; side < 2; ++side)
    {
        SgExpression* child = children[side];
        if (!child)
            continue;

        if (IsConstElementRef(child))
        {
            SgExpression* replacement = GetOrCreateVar(child, arrayToVariable, variableNumber)->copyPtr();
            if (side == 0)
                exp->setLhs(replacement);
            else
                exp->setRhs(replacement);
        }
        else
            ReplaceNestedArrayRefs(child, arrayToVariable, variableNumber);
    }
}

// Rewrites the right-hand side `exp` of statement `st`.
// If the whole RHS is a constant-subscript array element, it is replaced by
// that element's temporary. Otherwise every such element inside it is
// replaced.
static void TransformRightPart(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (!exp)
        return;

    if (IsConstElementRef(exp))
    {
        // Key by the element's own text; the former code used an empty key here.
        st->setExpression(1, GetOrCreateVar(exp, arrayToVariable, variableNumber)->copyPtr());
        return;
    }

    ReplaceNestedArrayRefs(exp, arrayToVariable, variableNumber);
}

// Inserts `tmp = <rhs of st>` before the assignment `st`, whose target is the
// array element `exp`, so the temporary tracks the element's value. Skips
// T_STRING-typed symbols and statements already handled.
static void TransformLeftPart(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (exp->symbol()->type()->variant() == T_STRING)
        return;
    if (changed.find(st) != changed.end())
        return;

    SgExpression* var = GetOrCreateVar(exp, arrayToVariable, variableNumber);
    SgExpression* rhs = st->expr(1);
    if (!var || !rhs)
        return;

    SgStatement* newStatement = new SgStatement(ASSIGN_STAT, NULL, NULL, var->copyPtr(), rhs->copyPtr(), NULL);
    newStatement->setFileId(st->getFileId());
    newStatement->setProject(st->getProject());
    newStatement->setlineNumber(getNextNegativeLineNumber());
    newStatement->setLocalLineNumber(st->lineNumber());
    st->insertStmtBefore(*newStatement, *st->controlParent());
    changed.insert(st);
}

// One step of a backward walk. If `st` assigns to an array element that
// already has a temporary, its RHS is rewritten (which may register new
// elements to track further up) and `tmp = <rhs>` is inserted before it.
static void PropagateThroughAssignment(SgStatement* st, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (st->variant() != ASSIGN_STAT)
        return;

    SgExpression* target = st->expr(0);
    if (!target || arrayToVariable.find(target->unparse()) == arrayToVariable.end())
        return;

    if (st->expr(1))
        TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber);

    target = st->expr(0);
    if (target && target->variant() == ARRAY_REF && CheckConstIndexes(target->lhs()))
        TransformLeftPart(st, target, arrayToVariable, variableNumber);
}

// Ensures the loop-bound element `exp` of loop `st` has a temporary. It then
// walks backwards from the loop to the start of the function, propagating
// through every assignment to tracked elements.
static void TransformBorder(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    SgStatement* firstStatement = declPlace->lexPrev();
    SgStatement* current = st->lexPrev();

    // Reuse an existing temporary. Creating a fresh one here would leave it
    // unassigned, because statements already in `changed` are skipped.
    GetOrCreateVar(exp, arrayToVariable, variableNumber);

    for (; current && current != firstStatement; current = current->lexPrev())
        PropagateThroughAssignment(current, arrayToVariable, variableNumber);
}

// For a scalar loop bound `exp`, walks backwards from loop `st`. Constant array
// elements on the RHS of each assignment to that scalar are replaced by
// temporaries, and the walk propagates through assignments to tracked
// elements.
static void CheckVariable(SgStatement* st, SgExpression* exp, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    SgStatement* firstStatement = declPlace->lexPrev();

    for (SgStatement* current = st->lexPrev(); current && current != firstStatement; current = current->lexPrev())
    {
        if (current->variant() == ASSIGN_STAT && current->expr(0) && current->expr(0)->symbol() == exp->symbol())
            TransformRightPart(current, current->expr(1), arrayToVariable, variableNumber);

        PropagateThroughAssignment(current, arrayToVariable, variableNumber);
    }
}

// Processes one loop bound. Returns a replacement expression (a copy of the
// bound's temporary) when the bound is a constant array element. Otherwise it
// returns NULL, after first propagating into the assignments of a scalar
// bound.
static SgExpression* PropagateIntoBound(SgStatement* loop, SgExpression* bound, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (!bound)
        return NULL;

    if (IsConstElementRef(bound))
    {
        const string key = bound->unparse();
        TransformBorder(loop, bound, arrayToVariable, variableNumber);
        return arrayToVariable.at(key)->copyPtr();
    }

    if (bound->variant() == VAR_REF)
        CheckVariable(loop, bound, arrayToVariable, variableNumber);
    return NULL;
}

// Applies propagation to the upper bound and then the lower bound of the DO
// loop `loop`, rewriting each bound that is a constant array element.
static void PropagateLoopBounds(SgStatement* loop, unordered_map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    SgExpression* range = loop->expr(0);
    if (!range)
        return;

    SgExpression* lowerBound = range->lhs();
    SgExpression* upperBound = range->rhs();

    if (SgExpression* newUpper = PropagateIntoBound(loop, upperBound, arrayToVariable, variableNumber))
        loop->expr(0)->setRhs(newUpper);
    if (SgExpression* newLower = PropagateIntoBound(loop, lowerBound, arrayToVariable, variableNumber))
        loop->expr(0)->setLhs(newLower);
}

// Entry point: runs array-element constant propagation into DO-loop bounds
// for every function of every file in `project`.
void ArrayConstantPropagation(SgProject& project)
{
    // Counter shared by all files so generated names stay unique program-wide.
    int variableNumber = 0;

    for (int fileIdx = 0; fileIdx < project.numberOfFiles(); ++fileIdx)
    {
        SgFile* file = &(project.file(fileIdx));
        if (!file)
            continue;

        const int funcNum = file->numberOfFunctions();
        for (int funcIdx = 0; funcIdx < funcNum; ++funcIdx)
        {
            SgStatement* funcHeader = file->functions(funcIdx);
            declPlace = funcHeader->lexNext();

            // Temporaries are declared only in the COMMON statement of the
            // function that created them. The mapping must therefore not be
            // reused in another function, where the temporary would be
            // undeclared.
            unordered_map<string, SgExpression*> arrayToVariable;

            SgStatement* lastNode = funcHeader->lastNodeOfStmt();
            for (SgStatement* st = funcHeader; st != lastNode; st = st->lexNext())
            {
                if (st->variant() == FOR_NODE)
                    PropagateLoopBounds(st, arrayToVariable, variableNumber);
            }

            // NOTE(review): dumps every processed function to stdout; looks
            // like debug output -- consider removing or gating it.
            cout << file->functions(funcIdx)->unparse() << endl;
        }
    }
}