diff --git a/src/ArrayConstantPropagation/propagation.cpp b/src/ArrayConstantPropagation/propagation.cpp index bf56a3f..2e94e69 100644 --- a/src/ArrayConstantPropagation/propagation.cpp +++ b/src/ArrayConstantPropagation/propagation.cpp @@ -9,7 +9,10 @@ using namespace std; -static SgStatement* declPlace = NULL; +SgStatement* declPlace = NULL; +unordered_set changed; +unordered_set variablesToAdd; +unordered_set positionsToAdd; static bool CheckConstIndexes(SgExpression* exp) { @@ -40,137 +43,117 @@ static SgExpression* CreateVar(int& variableNumber, SgType* type) string name = varName + std::to_string(variableNumber) + "__"; variableNumber++; - SgSymbol* varSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), *type, *declPlace->controlParent()); + SgStatement* funcStart = declPlace->controlParent(); + SgSymbol* varSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), *type, *funcStart); + + variablesToAdd.insert(varSymbol); + positionsToAdd.insert(declPlace); + + return new SgExpression(VAR_REF, NULL, NULL, varSymbol, type->copyPtr()); +} + +static SgStatement* FindLastDeclStatement(SgStatement* funcStart) +{ + SgStatement* endSt = funcStart->lastNodeOfStmt(); + SgStatement* cur = funcStart->lexNext(); + SgStatement* lastDecl = funcStart; + const set declVariants = { VAR_DECL, VAR_DECL_90, ALLOCATABLE_STMT, DIM_STAT, + EXTERN_STAT, COMM_STAT, HPF_TEMPLATE_STAT, DVM_VAR_DECL, STRUCT_DECL }; + + while (cur && cur != endSt) + { + if (cur->variant() == INTERFACE_STMT) + cur = cur->lastNodeOfStmt(); + + if (declVariants.find(cur->variant()) != declVariants.end()) + lastDecl = cur; + else if (isSgExecutableStatement(cur)) + break; + + cur = cur->lexNext(); + } + + return lastDecl; +} + +static void InsertCommonAndDeclsForFunction(SgStatement* funcStart, const unordered_set& symbols) +{ + if (symbols.empty()) + return; const string commonBlockName = "__propagation_common__"; - SgStatement* funcStart = declPlace->controlParent(); + SgStatement* funcEnd = funcStart->lastNodeOfStmt(); SgStatement* commonStat = NULL; SgExpression* commonList = NULL; - SgStatement* funcEnd = funcStart->lastNodeOfStmt(); - SgStatement* current = funcStart->lexNext(); - - while (current != funcEnd && current) + for (SgStatement* cur = funcStart->lexNext(); + cur && cur != funcEnd; cur = cur->lexNext()) { - if (current->variant() == COMM_STAT) + if (cur->variant() != COMM_STAT) + continue; + + for (SgExpression* exp = cur->expr(0); exp; exp = exp->rhs()) { - for (SgExpression* exp = current->expr(0); exp; exp = exp->rhs()) + if (exp->variant() != COMM_LIST) + continue; + + const char* id = exp->symbol() ? exp->symbol()->identifier() : NULL; + string existingName = id ? string(id) : string("spf_unnamed"); + if (existingName == commonBlockName) { - if (exp->variant() == COMM_LIST) - { - string existingName = exp->symbol() ? - string(exp->symbol()->identifier()) : - string("spf_unnamed"); - if (existingName == commonBlockName) - { - commonStat = current; - commonList = exp; - break; - } - } - } - if (commonStat) + commonStat = cur; + commonList = exp; break; + } } - current = current->lexNext(); + if (commonStat) + break; } vector varRefs; - if (commonList) - { - SgExpression* varList = commonList->lhs(); - if (varList) - { - auto extractSymbol = [](SgExpression* exp) -> SgSymbol* { - if (!exp) - return NULL; - if (exp->symbol()) - return exp->symbol(); - if (exp->lhs() && exp->lhs()->symbol()) - return exp->lhs()->symbol(); - return NULL; - }; - if (varList->variant() == EXPR_LIST) - { - for (SgExpression* exp = varList; exp; exp = exp->rhs()) - { - SgExpression* varExp = exp->lhs(); - SgSymbol* sym = extractSymbol(varExp); - if (sym) - { - varRefs.push_back(new SgVarRefExp(sym)); - } - } - } - else - { - for (SgExpression* varExp = varList; varExp; varExp = varExp->rhs()) - { - SgSymbol* sym = extractSymbol(varExp); - if (sym) - { - varRefs.push_back(new SgVarRefExp(sym)); - } - } - } - } - } + for (SgSymbol* sym : symbols) + varRefs.push_back(new SgVarRefExp(sym)); + SgExpression* varList = makeExprList(varRefs, false); if (!commonList) { - current = funcStart->lexNext(); - while (current != funcEnd && current) - { - if (current->variant() == COMM_STAT) - { - commonStat = current; - break; - } - current = current->lexNext(); - } - SgSymbol* commonSymbol = new SgSymbol(COMMON_NAME, commonBlockName.c_str()); - commonList = new SgExpression(COMM_LIST, NULL, NULL, commonSymbol); + commonList = new SgExpression(COMM_LIST, varList, NULL, commonSymbol); - if (commonStat) - { - SgExpression* lastCommList = commonStat->expr(0); - if (lastCommList) - { - while (lastCommList->rhs()) - lastCommList = lastCommList->rhs(); - lastCommList->setRhs(commonList); - } - else - { - commonStat->setExpression(0, commonList); - } - } - else - { - commonStat = new SgStatement(COMM_STAT); - commonStat->setFileName(declPlace->fileName()); - commonStat->setFileId(declPlace->getFileId()); - commonStat->setProject(declPlace->getProject()); - commonStat->setlineNumber(getNextNegativeLineNumber()); - commonStat->setExpression(0, commonList); - - declPlace->insertStmtBefore(*commonStat, *declPlace->controlParent()); - } + commonStat = new SgStatement(COMM_STAT); + commonStat->setFileName(funcStart->fileName()); + commonStat->setFileId(funcStart->getFileId()); + commonStat->setProject(funcStart->getProject()); + commonStat->setlineNumber(getNextNegativeLineNumber()); + commonStat->setExpression(0, commonList); + SgStatement* lastDecl = FindLastDeclStatement(funcStart); + lastDecl->insertStmtAfter(*commonStat, *funcStart); } - varRefs.push_back(new SgVarRefExp(varSymbol)); - - if (varRefs.size() > 0) + else { - std::reverse(varRefs.begin(), varRefs.end()); - SgExpression* varList = makeExprList(varRefs, false); - commonList->setLhs(varList); } - return new SgExpression(VAR_REF, NULL, NULL, varSymbol, type->copyPtr()); + SgStatement* insertAfter = commonStat; + for (SgSymbol* sym : symbols) + { + SgStatement* declStmt = sym->makeVarDeclStmt(); + if (!declStmt) + continue; + + if (SgVarDeclStmt* vds = isSgVarDeclStmt(declStmt)) + vds->setVariant(VAR_DECL_90); + + declStmt->setFileName(funcStart->fileName()); + declStmt->setFileId(funcStart->getFileId()); + declStmt->setProject(funcStart->getProject()); + declStmt->setlineNumber(getNextNegativeLineNumber()); + + insertAfter->insertStmtAfter(*declStmt, *funcStart); + insertAfter = declStmt; + } } static void TransformRightPart(SgStatement* st, SgExpression* exp, unordered_map& arrayToVariable, int& variableNumber) @@ -227,6 +210,8 @@ static void TransformLeftPart(SgStatement* st, SgExpression* exp, unordered_map< { if (exp->symbol()->type()->variant() == T_STRING) return; + if (changed.find(st) != changed.end()) + return; string expUnparsed = exp->unparse(); if (arrayToVariable.find(expUnparsed) == arrayToVariable.end() && exp->symbol()->type()->baseType()) { @@ -234,12 +219,62 @@ static void TransformLeftPart(SgStatement* st, SgExpression* exp, unordered_map< } SgStatement* newStatement = new SgStatement(ASSIGN_STAT, NULL, NULL, arrayToVariable[expUnparsed]->copyPtr(), st->expr(1)->copyPtr(), NULL); - newStatement->setFileId(st->getFileId()); + newStatement->setFileId(st->getFileId()); newStatement->setProject(st->getProject()); newStatement->setlineNumber(getNextNegativeLineNumber()); newStatement->setLocalLineNumber(st->lineNumber()); st->insertStmtBefore(*newStatement, *st->controlParent()); + changed.insert(st); +} + +static void TransformBorder(SgStatement* st, SgExpression* exp, unordered_map& arrayToVariable, int& variableNumber) +{ + SgStatement* firstStatement = declPlace->lexPrev(); + st = st->lexPrev(); + string array = exp->unparse(); + arrayToVariable[array] = CreateVar(variableNumber, exp->symbol()->type()->baseType()); + while (st != firstStatement) + { + if (st->variant() == ASSIGN_STAT && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end()) + { + if (st->expr(1)) + { + TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber); + } + if (st->expr(0) && st->expr(0)->variant() == ARRAY_REF && CheckConstIndexes(st->expr(0)->lhs()) && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end()) + { + TransformLeftPart(st, st->expr(0), arrayToVariable, variableNumber); + } + } + st = st->lexPrev(); + } +} + +static void CheckVariable(SgStatement* st, SgExpression* exp, unordered_map& arrayToVariable, int& variableNumber) +{ + SgStatement* firstStatement = declPlace->lexPrev(); + st = st->lexPrev(); + string varName = exp->unparse(); + while (st != firstStatement) + { + if (st->variant() == ASSIGN_STAT && st->expr(0)->symbol() == exp->symbol()) + { + TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber); + } + if (st->variant() == ASSIGN_STAT && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end()) + { + if (st->expr(1)) + { + TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber); + } + if (st->expr(0) && st->expr(0)->variant() == ARRAY_REF && CheckConstIndexes(st->expr(0)->lhs()) && arrayToVariable.find(st->expr(0)->unparse()) != arrayToVariable.end()) + { + TransformLeftPart(st, st->expr(0), arrayToVariable, variableNumber); + } + } + st = st->lexPrev(); + } } void ArrayConstantPropagation(SgProject& project) @@ -262,40 +297,37 @@ void ArrayConstantPropagation(SgProject& project) for (; st != lastNode; st = st->lexNext()) { - if (st->variant() == ASSIGN_STAT) - { - if (st->expr(1)) - { - TransformRightPart(st, st->expr(1), arrayToVariable, variableNumber); - } - if (st->expr(0) && st->expr(0)->variant() == ARRAY_REF && CheckConstIndexes(st->expr(0)->lhs())) - { - TransformLeftPart(st, st->expr(0), arrayToVariable, variableNumber); - } - } - else if (st->variant() == FOR_NODE) + if (st->variant() == FOR_NODE) { SgExpression* lowerBound = st->expr(0)->lhs(); SgExpression* upperBound = st->expr(0)->rhs(); string lowerBoundUnparsed = lowerBound->unparse(), upperBoundUnparsed = upperBound->unparse(); if (upperBound->variant() == ARRAY_REF && upperBound->symbol()->type()->baseType() && CheckConstIndexes(upperBound->lhs())) { - if (arrayToVariable.find(upperBoundUnparsed) == arrayToVariable.end()) - { - arrayToVariable[upperBoundUnparsed] = CreateVar(variableNumber, upperBound->symbol()->type()->baseType()); - } + TransformBorder(st, upperBound, arrayToVariable, variableNumber); st->expr(0)->setRhs(arrayToVariable[upperBoundUnparsed]->copyPtr()); } + else if (upperBound->variant() == VAR_REF) + CheckVariable(st, upperBound, arrayToVariable, variableNumber); + if (lowerBound->variant() == ARRAY_REF && lowerBound->symbol()->type()->baseType() && CheckConstIndexes(lowerBound->lhs())) { - if (arrayToVariable.find(lowerBoundUnparsed) == arrayToVariable.end()) - { - arrayToVariable[lowerBoundUnparsed] = CreateVar(variableNumber, lowerBound->symbol()->type()->baseType()); - } + TransformBorder(st, lowerBound, arrayToVariable, variableNumber); st->expr(0)->setLhs(arrayToVariable[lowerBoundUnparsed]->copyPtr()); } + else if (lowerBound->variant() == VAR_REF) + CheckVariable(st, lowerBound, arrayToVariable, variableNumber); } } } + unordered_set funcStarts; + for (SgStatement* st : positionsToAdd) + { + SgStatement* scope = st->controlParent(); + if (scope) + funcStarts.insert(scope); + } + for (const auto& st: funcStarts) + InsertCommonAndDeclsForFunction(st, variablesToAdd); } } \ No newline at end of file diff --git a/src/PrivateAnalyzer/private_arrays_search.cpp b/src/PrivateAnalyzer/private_arrays_search.cpp index 95624f5..e50ae25 100644 --- a/src/PrivateAnalyzer/private_arrays_search.cpp +++ b/src/PrivateAnalyzer/private_arrays_search.cpp @@ -33,10 +33,9 @@ static void RemoveEmptyPoints(ArrayAccessingIndexes& container) points.push_back(arrayPoint); } - if (points.size() < accessingSet.GetElements().size() && !points.empty()) + if (!points.empty()) resultContainer[arrayName] = points; - - if (points.empty()) + else toRemove.insert(arrayName); } @@ -281,7 +280,6 @@ static void SolveDataFlow(Region* DFG) static bool getArrayDeclaredDimensions(SgArrayRefExp* arrayRef, vector& declaredDims) { - declaredDims.clear(); if (!arrayRef || !arrayRef->symbol() || !isSgArrayType(arrayRef->symbol()->type())) return false; SgArrayType* arrayType = (SgArrayType*)arrayRef->symbol()->type(); @@ -290,17 +288,22 @@ static bool getArrayDeclaredDimensions(SgArrayRefExp* arrayRef, vector { SgExpression* sizeExpr = arrayType->sizeInDim(i); SgConstantSymb* constValSymb = isSgConstantSymb(sizeExpr->symbol()); - string strDimLength; + SgSubscriptExp* subscriptExpr = isSgSubscriptExp(sizeExpr); + uint64_t dimLength; if (sizeExpr && sizeExpr->variant() == INT_VAL) - strDimLength = sizeExpr->unparse(); + dimLength = stol(sizeExpr->unparse()); else if (constValSymb) - strDimLength = constValSymb->constantValue()->unparse(); + dimLength = stol(constValSymb->constantValue()->unparse()); + else if (subscriptExpr) + { + dimLength = stol(subscriptExpr->rhs()->unparse()) - stol(subscriptExpr->lhs()->unparse()); + } else return false; - if (strDimLength == "0") + if (dimLength == 0) return false; - declaredDims.push_back((uint64_t)stoi(strDimLength)); + declaredDims.push_back(dimLength); } return true; } @@ -313,7 +316,8 @@ static bool CheckDimensionLength(const AccessingSet& array) SgArrayRefExp* arrayRef = array.GetElements()[0][0].array; if (!arrayRef) return false; - vector declaredDims(dimCount); + vector declaredDims; + declaredDims.reserve(dimCount); if (!getArrayDeclaredDimensions(arrayRef, declaredDims)) return false; vector testArray(dimCount); diff --git a/src/PrivateAnalyzer/region.cpp b/src/PrivateAnalyzer/region.cpp index 5252a7f..4a0ca2d 100644 --- a/src/PrivateAnalyzer/region.cpp +++ b/src/PrivateAnalyzer/region.cpp @@ -206,7 +206,7 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces current_dim = { start, step, iters, ref }; } - if (current_dim.start != 0 && current_dim.step != 0 && current_dim.tripCount != 0) + if (current_dim.step != 0 && current_dim.tripCount != 0) { accessPoint[n - index_vars.size()] = current_dim; fillCount++;