Files
SAPFOR/Sapfor/_src/Transformations/loops_combiner.cpp
2025-03-25 20:39:29 +03:00

1837 lines
60 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "loops_combiner.h"
#include "../LoopAnalyzer/loop_analyzer.h"
#include "../ExpressionTransform/expr_transform.h"
#include "../Utils/errors.h"
#include "../Utils/SgUtils.h"
#include <string>
#include <vector>
#include <queue>
using std::string;
using std::vector;
using std::map;
using std::set;
using std::pair;
using std::make_pair;
using std::queue;
using std::wstring;
// Greatest common divisor of |a| and |b| using Euclid's algorithm.
// gcd(x, 0) == |x| and gcd(0, 0) == 0.
static int gcd(int a, int b)
{
    // The former subtraction-based loop never terminated when one argument
    // was 0 (x - 0 == x) or when the arguments had different signs.
    a = a < 0 ? -a : a;
    b = b < 0 ? -b : b;
    while (b != 0)
    {
        const int rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Returns the DO-variable symbol of the FOR statement behind `loop`, or NULL
// when `loop` is NULL or is not a FOR loop.
static SgSymbol* getLoopSymbol(const LoopGraph* loop)
{
    if (loop == NULL)
        return NULL;
    if (!loop->isFor)
        return NULL;
    SgForStmt* const forStmt = (SgForStmt*)loop->loop->GetOriginal();
    return forStmt->doName();
}
// Collects DO-variable symbols into `vars`.
//  dimensions == -1: `loop` and, recursively, every loop nested in it.
//  dimensions >= 0 : the first `dimensions` levels of the nest rooted at
//                    `loop`, following children[0]; stops early when the nest
//                    is shallower than requested.
static void fillIterationVariables(const LoopGraph* loop, set<SgSymbol*>& vars, int dimensions = -1)
{
    if (dimensions == -1)
    {
        SgSymbol* sym = getLoopSymbol(loop);
        if (sym)
            vars.insert(sym);
        for (LoopGraph* child : loop->children)
            fillIterationVariables(child, vars);
        return;
    }

    for (int i = 0; i < dimensions && loop != NULL; ++i)
    {
        SgSymbol* sym = getLoopSymbol(loop);
        if (sym)
            vars.insert(sym);
        if (i != dimensions - 1)
        {
            // children[0] of a loop without children was undefined behaviour.
            if (loop->children.empty())
                break;
            loop = loop->children[0];
        }
    }
}
// Removes from `symbols` the first element that isEqSymbols() reports equal
// to `symbol` (not a pointer-identity comparison). No-op if nothing matches.
static void eraseSymbolFromSet(set<SgSymbol*>& symbols, SgSymbol* symbol)
{
    for (auto it = symbols.begin(); it != symbols.end(); ++it)
    {
        if (isEqSymbols(*it, symbol))
        {
            symbols.erase(it);
            return;
        }
    }
}
// True if some element of `symbols` is isEqSymbols()-equal to `symbol`.
static bool isSymbolInSet(const set<SgSymbol*>& symbols, SgSymbol* symbol)
{
    bool found = false;
    for (auto it = symbols.begin(); !found && it != symbols.end(); ++it)
        found = isEqSymbols(*it, symbol);
    return found;
}
// Inserts into `intersection` every element of `firstSet` that has an
// isEqSymbols()-equal counterpart in `secondSet`. The pointers inserted are
// taken from `firstSet`; `intersection` is not cleared beforehand.
static void getIntersection(const set<SgSymbol*>& firstSet, const set<SgSymbol*>& secondSet, set<SgSymbol*>& intersection)
{
    for (SgSymbol* candidate : firstSet)
    {
        bool matched = false;
        for (auto it = secondSet.begin(); !matched && it != secondSet.end(); ++it)
            matched = isEqSymbols(candidate, *it);
        if (matched)
            intersection.insert(candidate);
    }
}
// Returns true if `stmt` carries a label that is the branch target of some
// GOTO_NODE statement in the program unit returned by getFuncStat(stmt).
// The scan runs from the statement after the unit header up to, but not
// including, the unit's last node. Labels are compared by id().
// NOTE(review): only GOTO_NODE statements are inspected here.
static bool hasGotoToStatement(SgStatement* stmt)
{
    if (!stmt->hasLabel())
        return false;

    SgStatement* unit = getFuncStat(stmt);
    for (SgStatement* cur = unit->lexNext(); cur != unit->lastNodeOfStmt(); cur = cur->lexNext())
    {
        if (cur->variant() != GOTO_NODE)
            continue;
        SgLabel* target = ((SgGotoStmt*)cur)->branchLabel();
        if (target->id() == stmt->label()->id())
            return true;
    }
    return false;
}
// Checks whether exp1 and exp2 are equal. Each non-NULL argument is
// constant-folded (CalculateInteger on a copy) and unparsed; the texts are
// compared. A NULL argument counts as the empty text, so two NULLs are equal.
static bool isEqExpressions(SgExpression* exp1, SgExpression* exp2)
{
    auto toText = [](SgExpression* ex) -> string
    {
        if (ex == NULL)
            return string();
        return CalculateInteger(ex->copyPtr())->unparse();
    };
    const string firstText = toText(exp1);
    const string secondText = toText(exp2);
    return firstText == secondText;
}
// Checks whether exp1 == -exp2 after constant folding. Recognised forms:
// one side is a unary minus whose operand equals the other side, or both
// sides are integer literals with opposite values. NULL arguments give false.
static bool isOppositeExpressions(SgExpression* exp1, SgExpression* exp2)
{
    if (exp1 == NULL || exp2 == NULL)
        return false;

    SgExpression* const left = CalculateInteger(exp1->copyPtr());
    SgExpression* const right = CalculateInteger(exp2->copyPtr());

    if (left->variant() == MINUS_OP)
        return isEqExpressions(left->lhs(), right);
    if (right->variant() == MINUS_OP)
        return isEqExpressions(left, right->lhs());

    const bool bothLiterals = left->variant() == INT_VAL && right->variant() == INT_VAL;
    return bothLiterals && left->valueInteger() == -right->valueInteger();
}
// Returns true when every dependency recorded for `loop` in
// depInfoForLoopGraph is either an array output/reduction dependency or a
// private-variable dependency. Such dependencies are treated as not
// preventing execution of the loop in reverse order.
// Returns false when no dependency graph is recorded (missing or NULL).
static bool ifLoopCanBeReversed(LoopGraph* loop, const map<LoopGraph*, depGraph*>& depInfoForLoopGraph)
{
    auto dependency = depInfoForLoopGraph.find(loop);
    // A NULL graph used to be dereferenced; treat it like missing info.
    if (dependency == depInfoForLoopGraph.end() || dependency->second == NULL)
        return false;

    const vector<depNode*> nodes = dependency->second->getNodes();
    for (depNode* node : nodes)
    {
        const int type = node->typedep;
        const ddnature nature = (ddnature)node->kinddep;
        if (type == ARRAYDEP && (nature == ddoutput || nature == ddreduce))
            continue;
        if (type == PRIVATEDEP)
            continue;
        return false;
    }
    return true;
}
// Reverses the iteration direction of the first `dimensions` loops of the nest
// rooted at `loop` (following children[0]): DO i = a, b, s becomes
// DO i = b, a, -s. The cached numeric bounds (when the iteration count is
// known) and the Sage FOR statement are both updated. startEndStepVals is then
// rebuilt from the modified statement.
// NOTE(review): swapping the bounds keeps the same set of iteration values only
// when (b - a) is a multiple of s (e.g. unit step). Callers are presumably
// expected to guarantee this — confirm.
static void reverseLoop(LoopGraph* loop, int dimensions)
{
if (loop == NULL)
return;
for (int i = 0; i < dimensions; ++i)
{
// cached numeric bounds are only maintained when the iteration count is known
if (loop->calculatedCountOfIters != 0) {
std::swap(loop->startVal, loop->endVal);
loop->stepVal *= -1;
}
SgForStmt* loopStmt = isSgForStmt(loop->loop->GetOriginal());
checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);
// swap start and end in the statement; start is copied before it is overwritten
SgExpression& start = loopStmt->start()->copy();
loopStmt->setStart(loopStmt->end()->copy());
loopStmt->setEnd(start);
// an absent step means step 1
SgExpression* tmpEx = loopStmt->step();
if (tmpEx == NULL)
tmpEx = new SgValueExp(1);
// negate the step: -(x) becomes x, anything else becomes -(step)
if (tmpEx->variant() == MINUS_OP)
{
SgExpression* lhs = tmpEx->lhs();
loopStmt->setStep(*lhs);
}
else
loopStmt->setStep(*(new SgExpression(MINUS_OP, tmpEx, NULL)));
// refresh the cached start/end/step expressions from the modified statement
Expression* startExpr = loopStmt->start() ? new Expression(loopStmt->start()) : NULL;
Expression* endExpr = loopStmt->end() ? new Expression(loopStmt->end()) : NULL;
Expression* stepExpr = loopStmt->step() ? new Expression(loopStmt->step()) : NULL;
loop->startEndStepVals = std::make_tuple(startExpr, endExpr, stepExpr);
// descend to the next level (assumes the nest is `dimensions` deep)
if (i != dimensions - 1)
loop = loop->children[0];
}
}
static bool isSimpleExpression(SgExpression* expr)
{
    // A "simple" expression is one of:
    //   integer named constant (CONST_REF) / INT_VAL / VAR_REF /
    //   unary minus of a VAR_REF or of an integer /
    //   VAR_REF +|- integer  or  integer +|- VAR_REF (exactly one variable).
    const int kind = expr->variant();

    if (kind == VAR_REF || kind == INT_VAL)
        return true;

    if (kind == CONST_REF)
    {
        SgConstantSymb* named = isSgConstantSymb(expr->symbol());
        return named != NULL && named->constantValue()->isInteger();
    }

    if (kind == MINUS_OP)
    {
        SgExpression* operand = expr->lhs();
        return operand->variant() == VAR_REF || operand->isInteger();
    }

    if (kind == ADD_OP || kind == SUBT_OP)
    {
        SgExpression* left = expr->lhs();
        SgExpression* right = expr->rhs();
        if (!left || !right)
            printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

        const bool leftIsVar = left->variant() == VAR_REF;
        const bool rightIsVar = right->variant() == VAR_REF;
        if (!leftIsVar && !left->isInteger())
            return false;
        if (!rightIsVar && !right->isInteger())
            return false;
        // exactly one of the two operands must be a variable
        return leftIsVar != rightIsVar;
    }

    return false;
}
// Returns true when both expressions reference exactly the same set of
// variable names (VAR_REF nodes, as collected by getVariables).
static bool hasEqualVars(SgExpression* firstExpr, SgExpression* secondExpr)
{
    set<string> firstVars;
    getVariables(firstExpr, firstVars, set<int> { VAR_REF });
    set<string> secondVars;
    getVariables(secondExpr, secondVars, set<int> { VAR_REF });
    // std::set equality compares sizes and elements. It replaces a manual loop
    // that also copied every name (`const string var` in the range-for).
    return firstVars == secondVars;
}
// Decomposes a "simple" expression (see isSimpleExpression) containing one
// variable into the form  (varMinus ? -var : var) + varAdd.
// Accepted shapes: var, -var, var + c, c + var, var - c, c - var. Any other
// shape is reported through printInternalError.
static void getSimpleExprVarParams(SgExpression* expr, bool* varMinus, int* varAdd)
{
    const int kind = expr->variant();

    if (kind == VAR_REF)
    {
        *varMinus = false;
        *varAdd = 0;
        return;
    }
    if (kind == MINUS_OP && expr->lhs()->variant() == VAR_REF)
    {
        *varMinus = true;
        *varAdd = 0;
        return;
    }
    if (kind == ADD_OP)
    {
        // the constant may stand on either side of '+'
        SgExpression* constPart = expr->lhs()->isInteger() ? expr->lhs() : expr->rhs();
        *varMinus = false;
        *varAdd = constPart->valueInteger();
        return;
    }
    if (kind == SUBT_OP)
    {
        SgExpression* left = expr->lhs();
        if (left->isInteger())
        {
            // c - var
            *varMinus = true;
            *varAdd = left->valueInteger();
        }
        else
        {
            // var - c
            *varMinus = false;
            *varAdd = -expr->rhs()->valueInteger();
        }
        return;
    }

    printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
// Compares two simple expressions (see isSimpleExpression).
// Returns 1 if equal, 0 if first < second, 2 if first > second, and -1 when
// they cannot be compared: the variable sets differ, or the variable appears
// with opposite signs.
static int compareSimpleExpressions(SgExpression* firstExpr, SgExpression* secondExpr)
{
    if (!hasEqualVars(firstExpr, secondExpr))
        return -1;

    // named constants are replaced by their values
    if (SgConstantSymb* named = isSgConstantSymb(firstExpr->symbol()))
        firstExpr = named->constantValue();
    if (SgConstantSymb* named = isSgConstantSymb(secondExpr->symbol()))
        secondExpr = named->constantValue();

    // three-way comparison mapped onto the 0 / 1 / 2 encoding
    auto order = [](int lhs, int rhs) { return lhs == rhs ? 1 : (lhs < rhs ? 0 : 2); };

    if (firstExpr->isInteger() && secondExpr->isInteger())
    {
        const int firstValue = firstExpr->valueInteger();
        const int secondValue = secondExpr->valueInteger();
        return order(firstValue, secondValue);
    }

    bool firstNegated = false, secondNegated = false;
    int firstShift = 0, secondShift = 0;
    getSimpleExprVarParams(firstExpr, &firstNegated, &firstShift);
    getSimpleExprVarParams(secondExpr, &secondNegated, &secondShift);
    if (firstNegated != secondNegated) // vars have different sign
        return -1;
    return order(firstShift, secondShift);
}
// Intended to decide whether two loops whose bounds differ (but are comparable)
// could still be combined. Currently disabled: it always returns false, and
// everything after the first `return false;` is unreachable.
static bool canBeCombinedWithDiffBounds(const LoopGraph* firstLoop, const LoopGraph* loop)
{
// TODO: remove once dependency analysis for arrays is added:
return false;
//
// --- unreachable below: kept for when the check is re-enabled ---
if (firstLoop->hasLimitsToCombine() || loop->hasLimitsToCombine())
return false;
SgForStmt* firstLoopStmt = isSgForStmt(firstLoop->loop->GetOriginal());
checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
SgForStmt* loopStmt = isSgForStmt(loop->loop->GetOriginal());
checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);
if (hasGotoToStatement(firstLoopStmt) || hasGotoToStatement(loopStmt))
return false;
// both loops must have simple start/end expressions
if (!isSimpleExpression(firstLoopStmt->start()) || !isSimpleExpression(firstLoopStmt->end()))
return false;
if (!isSimpleExpression(loopStmt->start()) || !isSimpleExpression(loopStmt->end()))
return false;
// numeric bounds known for both loops
if (firstLoop->calculatedCountOfIters != 0 && loop->calculatedCountOfIters != 0)
{
// intersection of ranges is not empty:
if (firstLoop->stepVal > 0)
{
if (firstLoop->startVal > loop->endVal || firstLoop->endVal < loop->startVal)
return false;
}
else
{
if (firstLoop->startVal < loop->endVal || firstLoop->endVal > loop->startVal)
return false;
}
if (firstLoop->stepVal * loop->stepVal < 0) // steps have different sign
return false;
return true;
}
// symbolic bounds: steps must be integer literals (default 1) of the same sign
SgExpression* step1 = firstLoopStmt->step();
SgExpression* step2 = loopStmt->step();
int step1Val = 1, step2Val = 1;
if (step1 != NULL)
{
if (step1->isInteger())
step1Val = step1->valueInteger();
else
return false;
}
if (step2 != NULL)
{
if (step2->isInteger())
step2Val = step2->valueInteger();
else
return false;
}
if (step1Val * step2Val < 0) // steps have different sign
return false;
// start and end expressions must be mutually comparable
int compStart = compareSimpleExpressions(firstLoopStmt->start(), loopStmt->start());
int compEnd = compareSimpleExpressions(firstLoopStmt->end(), loopStmt->end());
if (compStart == -1 || compEnd == -1) // impossible to compare
return false;
return true;
}
// Determines how many leading levels (at most perfectLoop) of the nests
// `firstLoop` and `loop` traverse the same range in opposite directions, so
// that reversing one of the nests would make their headers coincide.
// When the result is positive, *toReverse is set to firstLoop if its
// dependencies allow reversal (ifLoopCanBeReversed), otherwise to loop.
// *toReverse is not touched when 0 is returned.
// Returns 0 immediately if neither nest may be reversed.
static int getDeepestDimToReverse(LoopGraph* firstLoop, LoopGraph* loop, int perfectLoop,
                                  LoopGraph** toReverse, const map<LoopGraph*, depGraph*>& depInfoForLoopGraph)
{
    LoopGraph* curFirstLoop = firstLoop;
    LoopGraph* curLoop = loop;
    int i = 0;
    const bool canBeReversed1 = ifLoopCanBeReversed(firstLoop, depInfoForLoopGraph);
    const bool canBeReversed2 = ifLoopCanBeReversed(loop, depInfoForLoopGraph);
    if (!canBeReversed1 && !canBeReversed2)
        return 0;

    for (i = 0; i < perfectLoop; ++i)
    {
        SgForStmt* firstLoopStmt = isSgForStmt(curFirstLoop->loop->GetOriginal());
        checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
        SgForStmt* loopStmt = isSgForStmt(curLoop->loop->GetOriginal());
        checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);

        if (curLoop->hasLimitsToCombine() || hasGotoToStatement(loopStmt))
            break;

        if (curFirstLoop->calculatedCountOfIters != 0 && curLoop->calculatedCountOfIters != 0)
        {
            // numeric bounds: DO a, b, s  versus  DO b, a, -s
            if (curFirstLoop->startVal != curLoop->endVal)
                break;
            if (curFirstLoop->endVal != curLoop->startVal)
                break;
            if (curFirstLoop->stepVal != -1 * curLoop->stepVal)
                break;
            // Reversing DO a, b, s into DO b, a, -s visits the same values only
            // when (b - a) is a multiple of s: DO 1, 10, 2 visits 1..9 while
            // DO 10, 1, -2 visits 10..2.
            const auto step = curFirstLoop->stepVal;
            if (step == 0 || (curFirstLoop->endVal - curFirstLoop->startVal) % step != 0)
                break;
        }
        else
        {
            if (!isEqExpressions(std::get<0>(curFirstLoop->startEndStepVals), std::get<1>(curLoop->startEndStepVals)))
                break;
            if (!isEqExpressions(std::get<1>(curFirstLoop->startEndStepVals), std::get<0>(curLoop->startEndStepVals)))
                break;
            SgExpression* step1 = std::get<2>(curFirstLoop->startEndStepVals);
            SgExpression* step2 = std::get<2>(curLoop->startEndStepVals);
            SgValueExp defaultStep(1);
            if (step1 == NULL)
                step1 = &defaultStep;
            if (step2 == NULL)
                step2 = &defaultStep;
            if (!isOppositeExpressions(step1, step2))
                break;
            // With symbolic bounds the divisibility of (b - a) by the step
            // cannot be checked, so only unit steps (+1 / -1) are accepted.
            SgExpression* foldedStep = CalculateInteger(step1->copyPtr());
            if (!foldedStep->isInteger())
                break;
            const int stepValue = foldedStep->valueInteger();
            if (stepValue != 1 && stepValue != -1)
                break;
        }

        if (i != perfectLoop - 1)
        {
            if (curLoop->children.size() != 1)
                break;
            curFirstLoop = curFirstLoop->children[0];
            curLoop = curLoop->children[0];
        }
    }

    if (i > 0)
    {
        if (canBeReversed1)
            *toReverse = firstLoop;
        else
            *toReverse = loop;
    }
    return i;
}
/**
 * Finds the number of leading levels (at most perfectLoop) along which
 * `firstLoop` and `loop` can be combined. At each level (following
 * children[0]) `loop` must have no combine restrictions (hasLimitsToCombine),
 * no GOTO targeting its header, and the same start, end and step as
 * firstLoop. Numeric bounds are compared when both iteration counts are
 * known; otherwise the cached start/end/step expressions are compared, with
 * a missing step treated as 1.
 * NOTE(review): children[0] is taken without a check, so both nests are
 * assumed to be at least perfectLoop levels deep.
 */
static int getDeepestDimForCombine(const LoopGraph* firstLoop, const LoopGraph* loop, int perfectLoop)
{
const LoopGraph* curFirstLoop = firstLoop;
const LoopGraph* curLoop = loop;
int i = 0;
for (i = 0; i < perfectLoop; ++i)
{
SgForStmt* firstLoopStmt = isSgForStmt(curFirstLoop->loop->GetOriginal());
checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
SgForStmt* loopStmt = isSgForStmt(curLoop->loop->GetOriginal());
checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);
if (curLoop->hasLimitsToCombine() || hasGotoToStatement(loopStmt))
return i;
// both iteration counts known: compare the numeric bounds
if (curFirstLoop->calculatedCountOfIters != 0 && curLoop->calculatedCountOfIters != 0) {
if (curFirstLoop->startVal != curLoop->startVal)
return i;
if (curFirstLoop->endVal != curLoop->endVal)
return i;
if (curFirstLoop->stepVal != curLoop->stepVal)
return i;
}
else {
// startVal:
if (!isEqExpressions(std::get<0>(curFirstLoop->startEndStepVals), std::get<0>(curLoop->startEndStepVals)))
return i;
// endVal:
if (!isEqExpressions(std::get<1>(curFirstLoop->startEndStepVals), std::get<1>(curLoop->startEndStepVals)))
return i;
SgExpression* step1 = std::get<2>(curFirstLoop->startEndStepVals);
SgExpression* step2 = std::get<2>(curLoop->startEndStepVals);
if (!isEqExpressions(step1, step2))
{
// exactly one step omitted: compare the other one against the default 1
if ((step1 == NULL) ^ (step2 == NULL))
{
SgValueExp defaultStep(1);
if (step1 == NULL)
step1 = &defaultStep;
else
step2 = &defaultStep;
if (!isEqExpressions(step1, step2))
return i;
}
else
return i;
}
}
if (i != perfectLoop - 1)
{
curFirstLoop = curFirstLoop->children[0];
curLoop = curLoop->children[0];
}
}
return i;
}
static void compareIterationVars(const LoopGraph* firstLoop, const LoopGraph* loop, int dimensions, map<SgSymbol*, SgSymbol*>& symbols)
{
for (int i = 0; i < dimensions; ++i)
{
SgForStmt* firstLoopStmt = isSgForStmt(firstLoop->loop->GetOriginal());
checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
SgForStmt* loopStmt = isSgForStmt(loop->loop->GetOriginal());
checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);
if (!isEqSymbols(firstLoopStmt->doName(), loopStmt->doName()))
symbols.insert(make_pair(loopStmt->doName(), firstLoopStmt->doName()));
if (i != dimensions - 1)
{
firstLoop = firstLoop->children[0];
loop = loop->children[0];
}
}
}
// Creates a copy of `symbol` with a fresh name of the form <stem>_<N>.
// If the identifier already ends in "_<digits>", <stem> is the part before
// that suffix and numbering starts at <digits> + 1. Otherwise <stem> is the
// whole identifier and numbering starts at 1. checkSymbNameAndCorrect chooses
// the final number.
static SgSymbol* copySymbolAndRename(SgSymbol* symbol)
{
    string stem = symbol->identifier();
    int startNumber = 1;

    const size_t underscore = stem.rfind('_');
    if (underscore != string::npos && underscore + 1 < stem.length())
    {
        const string suffix = stem.substr(underscore + 1);
        bool onlyDigits = true;
        for (char ch : suffix)
        {
            if (ch < '0' || ch > '9')
            {
                onlyDigits = false;
                break;
            }
        }
        if (onlyDigits)
        {
            startNumber = atoi(suffix.c_str()) + 1;
            stem.erase(underscore);
        }
    }

    const int chosen = checkSymbNameAndCorrect(stem + '_', startNumber);
    const string freshName = stem + '_' + std::to_string(chosen);
    SgSymbol* renamed = &symbol->copy();
    renamed->changeName(freshName.c_str());
    return renamed;
}
// Recursively walks `ex`. Every VAR_REF or array reference whose symbol
// matches (isEqSymbols) a key of `symbols` gets the mapped symbol. Only the
// first matching key, in map order, is applied.
static void renameVariables(const map<SgSymbol*, SgSymbol*>& symbols, SgExpression* ex)
{
    if (ex == NULL)
        return;

    const bool isReference = ex->variant() == VAR_REF || isArrayRef(ex);
    if (isReference && ex->symbol())
    {
        auto match = symbols.begin();
        while (match != symbols.end() && !isEqSymbols(match->first, ex->symbol()))
            ++match;
        if (match != symbols.end())
            ex->setSymbol(match->second);
    }

    renameVariables(symbols, ex->lhs());
    renameVariables(symbols, ex->rhs());
}
// Updates the cached loopSymbol name of `loop` and of every nested loop.
// A name equal to the identifier of a key in `symbols` is replaced by the
// identifier of the mapped symbol (first match in map order).
static void renameIterationVariables(LoopGraph* loop, const map<SgSymbol*, SgSymbol*>& symbols)
{
    if (loop == NULL)
        return;

    for (const auto& entry : symbols)
    {
        if (loop->loopSymbol == entry.first->identifier())
        {
            loop->loopSymbol = string(entry.second->identifier());
            break;
        }
    }

    for (LoopGraph* child : loop->children)
        renameIterationVariables(child, symbols);
}
// Applies the renaming `symbols` throughout `loop`:
//  - cached loop names of the whole nest;
//  - DO-variables of FOR statements;
//  - variable/array references in the first three expressions of every
//    statement from the loop header up to, but not including, its last node.
static void renameVariablesInLoop(LoopGraph* loop, const map<SgSymbol*, SgSymbol*>& symbols)
{
    renameIterationVariables(loop, symbols);

    SgStatement* cur = loop->loop;
    while (cur != loop->loop->lastNodeOfStmt())
    {
        if (cur->variant() == FOR_NODE)
        {
            SgForStmt* forStmt = (SgForStmt*)cur;
            for (const auto& entry : symbols)
            {
                if (isEqSymbols(entry.first, forStmt->symbol()))
                    forStmt->setDoName(*entry.second);
            }
        }
        for (int idx = 0; idx < 3; ++idx)
            renameVariables(symbols, cur->expr(idx));
        cur = cur->lexNext();
    }
}
// Rewrites the private set recorded for `loop` in `mapPrivates`, and
// recursively those of its nested loops. Each symbol matching a key of
// `symbols` (isEqSymbols) is replaced by the mapped symbol. A loop without an
// entry in the map is left alone and its children are not visited.
static void renamePrivatesInMap(LoopGraph* loop, const map<SgSymbol*, SgSymbol*>& symbols, map<LoopGraph*, set<SgSymbol*>>& mapPrivates)
{
    auto entry = mapPrivates.find(loop);
    if (loop == NULL || entry == mapPrivates.end())
        return;

    set<SgSymbol*> renamed;
    for (SgSymbol* priv : entry->second)
    {
        SgSymbol* replacement = priv;
        for (const auto& rule : symbols)
        {
            if (isEqSymbols(priv, rule.first))
            {
                replacement = rule.second;
                break;
            }
        }
        renamed.insert(replacement);
    }
    entry->second = renamed;

    for (LoopGraph* child : loop->children)
        renamePrivatesInMap(child, symbols, mapPrivates);
}
// Adds the DO-variables of `loop` and of all loops nested in it to the
// private set recorded for `loop`, then repeats for each child. Symbols
// already present (per isEqSymbols) are not added again. A loop without an
// entry in `mapPrivates` is skipped together with its children.
static void addIterationVarsToMap(LoopGraph* loop, map<LoopGraph*, set<SgSymbol*>>& mapPrivates)
{
    auto entry = mapPrivates.find(loop);
    if (loop == NULL || entry == mapPrivates.end())
        return;

    set<SgSymbol*> nestVars;
    fillIterationVariables(loop, nestVars);
    set<SgSymbol*>& privates = entry->second;
    for (SgSymbol* var : nestVars)
    {
        if (isSymbolInSet(privates, var))
            continue;
        privates.insert(var);
    }

    for (LoopGraph* child : loop->children)
        addIterationVarsToMap(child, mapPrivates);
}
// For every loop in `loopGraphs`, and recursively for their children, records
// in `mapPrivates` the symbols that fillPrivatesFromComment extracts from the
// loop's SPF_ANALYSIS_DIR attributes, converted with OriginalSymbol.
// Existing entries are not overwritten (map::insert semantics).
static void fillMapPrivateVars(const vector<LoopGraph*>& loopGraphs, map<LoopGraph*, set<SgSymbol*>>& mapPrivates)
{
    for (LoopGraph* loop : loopGraphs)
    {
        set<Symbol*> declared;
        for (auto& data : getAttributes<SgStatement*, SgStatement*>(loop->loop, set<int>{ SPF_ANALYSIS_DIR }))
            fillPrivatesFromComment(new Statement(data), declared);

        set<SgSymbol*> privates;
        for (Symbol* sym : declared)
            privates.insert(OriginalSymbol((SgSymbol*)sym));
        mapPrivates.insert(make_pair(loop, privates));

        if (!loop->children.empty())
            fillMapPrivateVars(loop->children, mapPrivates);
    }
}
// Returns the FOR statement at nesting level `deep` of the nest rooted at
// `loop` (level 1 is `loop` itself), following children[0]. It never goes
// deeper than loop->perfectLoop levels; if `deep` exceeds that, the deepest
// reachable loop is returned. Returns NULL when deep <= 0.
static SgForStmt* getInnerLoop(const LoopGraph* loop, int deep)
{
    const int perfectLoop = loop->perfectLoop;
    const LoopGraph* curLoop = loop;
    SgForStmt* result = NULL;
    for (int i = 0; i < deep; ++i)
    {
        result = isSgForStmt(curLoop->loop->GetOriginal());
        checkNull(result, convertFileName(__FILE__).c_str(), __LINE__);
        // Descend only while another level is requested and it lies inside the
        // perfect nest. The old test (i != perfectLoop - 1) kept descending for
        // i >= perfectLoop and could index children[0] of a leaf loop.
        const bool wantDeeper = (i + 1 < deep) && (i + 1 < perfectLoop);
        if (wantDeeper && !curLoop->children.empty())
            curLoop = curLoop->children[0];
    }
    return result;
}
// Applies the renaming `symbols` to the body of `from`, i.e. every statement
// strictly between its header and its last node. The body is then extracted
// and inserted after the last executable statement of `to`.
static void moveBody(SgStatement* from, SgStatement* to, const map<SgSymbol*, SgSymbol*>& symbols)
{
    SgStatement* cur = from->lexNext();
    while (cur != from->lastNodeOfStmt())
    {
        for (int idx = 0; idx < 3; ++idx)
            renameVariables(symbols, cur->expr(idx));
        cur = cur->lexNext();
    }

    auto body = from->extractStmtBody();
    to->lastExecutable()->insertStmtAfter(*body, *to);
}
// Builds the value the DO-variable holds after normal completion of the
// header DO i = a, b, c (c defaults to 1 when absent):
//     a + ((b - a + c) / c) * c
// i.e. start + trip count * step. Fortran increments the DO-variable once more
// after the last iteration, so after DO i = 1, 10 the variable equals 11.
// NOTE(review): the Fortran trip count is max(0, (b - a + c) / c). This
// expression does not clamp at 0, so for a zero-trip loop it equals `a` only
// when (b - a + c) / c truncates to 0.
static SgExpression* createIterationCountExpr(const LoopGraph* loop)
{
    SgForStmt* firstLoopStmt = isSgForStmt(loop->loop->GetOriginal());
    checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
    SgExpression* a = firstLoopStmt->start();
    SgExpression* b = firstLoopStmt->end();
    SgExpression* c = firstLoopStmt->step();
    if (c == NULL)
        c = new SgValueExp(1);
    // The former formula a + ((b - a + c) / c - 1) * c gave the value of the
    // last iteration, one step short of the value observed after the loop.
    SgExpression* ex = &(*a + ((*b - *a + *c) / *c) * *c);
    return ex;
}
// Recursively replaces, inside `expression`, every direct operand whose symbol
// matches `var` (isEqSymbols) with `changeExpr`. The same `changeExpr` node is
// installed at every replaced position; the root node itself is never
// replaced.
// NOTE(review): recursion continues into the installed `changeExpr`, so a
// `changeExpr` that itself references `var` would not terminate.
static void changeVarToExpr(SgExpression* expression, SgSymbol* var, SgExpression* changeExpr)
{
    if (expression == NULL || var == NULL)
        return;

    auto refersToVar = [var](SgExpression* operand)
    {
        return operand != NULL && operand->symbol() != NULL && isEqSymbols(operand->symbol(), var);
    };

    if (refersToVar(expression->lhs()))
        expression->setLhs(changeExpr);
    if (refersToVar(expression->rhs()))
        expression->setRhs(changeExpr);

    changeVarToExpr(expression->lhs(), var, changeExpr);
    changeVarToExpr(expression->rhs(), var, changeExpr);
}
// Replaces `var` by `expr` in expression slots startExpr..2 of `statement`.
// A slot whose root symbol matches `var` is replaced as a whole. Any other
// slot is handled by the expression overload of changeVarToExpr.
static void changeVarToExpr(SgStatement* statement, SgSymbol* var, SgExpression* expr, int startExpr = 0)
{
    if (statement == NULL || var == NULL)
        return;

    for (int slot = startExpr; slot < 3; ++slot)
    {
        SgExpression* current = statement->expr(slot);
        const bool rootMatches = current && current->symbol() && isEqSymbols(current->symbol(), var);
        if (rootMatches)
            statement->setExpression(slot, expr);
        else
            changeVarToExpr(current, var, expr);
    }
}
// Finds, among the first `dimensions` loops of the nest rooted at `firstLoop`
// (following children[0]), the loop whose DO-variable is `var`. Every
// occurrence of `var` in the statements of `loop` (header up to, not
// including, its last node) is then replaced by createIterationCountExpr of
// that loop. Does nothing if no such loop is found.
static void changeIterationVarToCountExpr(LoopGraph* firstLoop, LoopGraph* loop, int dimensions, SgSymbol* var)
{
    LoopGraph* owner = NULL;
    LoopGraph* cur = firstLoop;
    for (int dim = 0; dim < dimensions && cur != NULL; ++dim)
    {
        if (isEqSymbols(getLoopSymbol(cur), var))
        {
            owner = cur;
            break;
        }
        // The old code always stepped into children[0], even past the last
        // requested level or on a leaf loop (undefined behaviour).
        cur = (dim + 1 < dimensions && !cur->children.empty()) ? cur->children[0] : NULL;
    }
    if (owner == NULL)
        return;

    SgExpression* countExpr = createIterationCountExpr(owner);
    for (SgStatement* st = loop->loop; st != loop->loop->lastNodeOfStmt(); st = st->lexNext())
        changeVarToExpr(st, var, countExpr);
}
// True if `var` (isEqSymbols) is referenced as a VAR_REF or as an array
// reference anywhere in the expression tree `ex`.
static bool isVarInExpression(SgSymbol* var, SgExpression* ex)
{
    if (ex == NULL)
        return false;
    const bool isReference = ex->variant() == VAR_REF || isArrayRef(ex);
    if (isReference && ex->symbol() && isEqSymbols(ex->symbol(), var))
        return true;
    return isVarInExpression(var, ex->lhs()) || isVarInExpression(var, ex->rhs());
}
// True if, within `loop` (header up to but excluding its last node), `var` is
// the left-hand-side symbol of an ASSIGN_STAT or the DO-variable of a FOR.
// NOTE(review): other kinds of modification (e.g. READ statements or CALL
// arguments) are not detected here.
static bool varIsChanged(SgSymbol* var, LoopGraph* loop)
{
    SgStatement* const header = loop->loop;
    for (SgStatement* st = header; st != header->lastNodeOfStmt(); st = st->lexNext())
    {
        const int kind = st->variant();
        if (kind == ASSIGN_STAT)
        {
            if (isEqSymbols(st->expr(0)->symbol(), var))
                return true;
        }
        else if (kind == FOR_NODE)
        {
            if (isEqSymbols(((SgForStmt*)st)->doName(), var))
                return true;
        }
    }
    return false;
}
// True if `var` occurs in any of the first three expressions of a statement
// of `loop` (header up to but excluding its last node). The left-hand side
// of an assignment whose target symbol is `var` itself is not counted.
static bool varIsRead(SgSymbol* var, LoopGraph* loop)
{
    for (SgStatement* st = loop->loop; st != loop->loop->lastNodeOfStmt(); st = st->lexNext())
    {
        const bool assignsVar = st->variant() == ASSIGN_STAT && isEqSymbols(st->expr(0)->symbol(), var);
        for (int slot = assignsVar ? 1 : 0; slot < 3; ++slot)
        {
            SgExpression* ex = st->expr(slot);
            if (ex != NULL && isVarInExpression(var, ex))
                return true;
        }
    }
    return false;
}
// True if some ASSIGN_STAT in the lexical range [begin, end) has `var` as the
// symbol of its left-hand side.
static bool varIsChangedBetween(SgSymbol* var, SgStatement* begin, SgStatement* end)
{
    SgStatement* st = begin;
    while (st != end)
    {
        if (st->variant() == ASSIGN_STAT && isEqSymbols(st->expr(0)->symbol(), var))
            return true;
        st = st->lexNext();
    }
    return false;
}
// True if `var` is referenced in any of the first three expressions of a
// statement in the lexical range [begin, end). Returns false if either bound
// is NULL.
static bool varIsUsedBetween(SgSymbol* var, SgStatement* begin, SgStatement* end)
{
    if (!begin || !end)
        return false;

    SgStatement* st = begin;
    while (st != end)
    {
        for (int slot = 0; slot < 3; ++slot)
        {
            SgExpression* ex = st->expr(slot);
            if (ex && isVarInExpression(var, ex))
                return true;
        }
        st = st->lexNext();
    }
    return false;
}
// Scans `loop` in lexical order up to the first ASSIGN_STAT whose target
// symbol is `var`.
// Returns true if that assignment reads `var` in expressions 1..2, or if `var`
// was referenced by any earlier statement (the loop header included).
// Returns false if `var` is never assigned before the loop's last node.
static bool isAntiVarDependency(SgSymbol* var, SgForStmt* loop)
{
    bool seenBefore = false;
    SgStatement* st = loop;
    while (st != loop->lastNodeOfStmt())
    {
        const bool assignsVar = st->variant() == ASSIGN_STAT && isEqSymbols(st->expr(0)->symbol(), var);
        if (assignsVar)
        {
            for (int slot = 1; slot < 3; ++slot)
                if (st->expr(slot) && isVarInExpression(var, st->expr(slot)))
                    return true;
            return seenBefore;
        }
        for (int slot = 0; slot < 3 && !seenBefore; ++slot)
            seenBefore = st->expr(slot) && isVarInExpression(var, st->expr(slot));
        st = st->lexNext();
    }
    return false;
}
// Finds, among the first `dimensions` loops of the nest rooted at `firstLoop`
// (following children[0]), the loop whose DO-variable is `var`. It then
//  - inserts "newSymbol = <createIterationCountExpr of that loop>" before the
//    header of `firstLoop`, and
//  - renames `var` to `newSymbol` throughout `loop`.
// Does nothing if no such loop is found.
static void replaceIterationVar(LoopGraph* firstLoop, LoopGraph* loop, int dimensions, SgSymbol* var, SgSymbol* newSymbol)
{
    LoopGraph* owner = NULL;
    LoopGraph* cur = firstLoop;
    for (int i = 0; i < dimensions && cur != NULL; ++i)
    {
        if (isEqSymbols(getLoopSymbol(cur), var))
        {
            owner = cur;
            break;
        }
        // children[0] used to be taken unconditionally: undefined behaviour on
        // a leaf loop and a walk past the requested levels.
        cur = (i + 1 < dimensions && !cur->children.empty()) ? cur->children[0] : NULL;
    }
    if (owner == NULL)
        return;

    SgExpression* countExpr = createIterationCountExpr(owner);
    SgStatement* st = new SgAssignStmt(*new SgVarRefExp(newSymbol), *countExpr);
    firstLoop->loop->insertStmtBefore(*st, *firstLoop->loop->controlParent());

    map<SgSymbol*, SgSymbol*> toRename;
    toRename.insert(make_pair(var, newSymbol));
    renameVariablesInLoop(loop, toRename);
}
// Returns true if `var` is referenced somewhere in `loop` where no nested
// loop's private list covers it. The statements of `loop`, from its header up
// to but excluding its last node, are scanned:
//  - a FOR statement that is one of loop->children is delegated. If `var` is
//    private there it is ignored; otherwise the child is checked recursively.
//    Either way the child's body is skipped.
//  - FOR statements themselves are not scanned for references. This includes
//    the header of `loop` and FOR statements that are not in loop->children.
//  - any other statement referencing `var` in expressions 0..2 yields true.
// mapPrivates is only read.
static bool varIsReallyNotPrivate(SgSymbol* var, const LoopGraph* loop, map<LoopGraph*, set<SgSymbol*>>& mapPrivates)
{
    for (SgStatement* st = loop->loop; st != loop->loop->lastNodeOfStmt(); st = st->lexNext())
    {
        if (st->variant() == FOR_NODE)
        {
            for (LoopGraph* child : loop->children)
            {
                if (child->loop->id() != st->id())
                    continue;

                // find() instead of operator[] so that no empty entries are
                // inserted into mapPrivates as a side effect.
                auto childPrivates = mapPrivates.find(child);
                const bool privateInChild = childPrivates != mapPrivates.end() && isSymbolInSet(childPrivates->second, var);
                // A non-private use in any child decides the answer. The old
                // code stored it in `res`, where a later child could overwrite
                // it with false.
                if (!privateInChild && varIsReallyNotPrivate(var, child, mapPrivates))
                    return true;
                st = st->lastNodeOfStmt();
            }
        }
        else
        {
            for (int i = 0; i < 3; ++i)
                if (st->expr(i) && isVarInExpression(var, st->expr(i)))
                    return true;
        }
    }
    return false;
}
// Climbs from `loop` through enclosing FOR statements (while controlParent()
// is a FOR_NODE) to the outermost one. `st` is inserted immediately before it,
// with that statement's control parent as the insertion scope.
static void insertStmtBeforeOuterLoop(SgStatement* st, SgStatement* loop)
{
    SgStatement* outer = loop;
    for (SgStatement* parent = outer->controlParent(); parent->variant() == FOR_NODE; parent = outer->controlParent())
        outer = parent;
    outer->insertStmtBefore(*st, *outer->controlParent());
}
// Fixes uses in `loop` of DO-variables that belong to the first `dimensions`
// levels of `firstLoop`; presumably these uses would see different values once
// the nests are combined (confirm).
// For each level, let var be firstLoop's DO-variable there. It is handled only
// if it is in `loopVars` and is not one of loop's own first-`dimensions`
// DO-variables:
//  - loop assigns var and isAntiVarDependency holds: var is renamed in `loop`
//    to a fresh, declared symbol. replaceIterationVar initialises that symbol
//    before the header of the current level's loop. loopVars is updated.
//  - loop assigns var without an anti-dependency: nothing is done.
//  - loop never assigns var: var is removed from loopVars and its uses in
//    `loop` are replaced by changeIterationVarToCountExpr.
// NOTE(review): firstLoopIterationVars and firstLoopVars are currently only
// referenced by the disabled code at the end.
static void correctInheritedUsage(LoopGraph* firstLoop, LoopGraph* loop, int dimensions, set<SgSymbol*>& firstLoopVars, set<SgSymbol*>& loopVars)
{
set<SgSymbol*> firstLoopIterationVars;
fillIterationVariables(firstLoop, firstLoopIterationVars, dimensions);
set<SgSymbol*> loopIterationVars;
fillIterationVariables(loop, loopIterationVars, dimensions);
LoopGraph* first = firstLoop;
for (int i = 0; i < dimensions; ++i)
{
SgSymbol* var = getLoopSymbol(first);
// var of firstLoop is used in loop without being one of loop's own DO-variables
if (isSymbolInSet(loopVars, var) && !isSymbolInSet(loopIterationVars, var))
{
if (varIsChanged(var, loop))
{
checkNull(isSgForStmt(loop->loop), convertFileName(__FILE__).c_str(), __LINE__);
// the inherited value is read before/while being overwritten: use a separate variable
if (isAntiVarDependency(var, (SgForStmt*)loop->loop))
{
SgSymbol* newSymbol = copySymbolAndRename(var);
eraseSymbolFromSet(loopVars, var);
loopVars.insert(newSymbol);
makeDeclaration(loop->loop, vector<SgSymbol*> { newSymbol });
replaceIterationVar(first, loop, dimensions, var, newSymbol);
}
}
else
{
// read-only use: substitute the expression computed from firstLoop's header
eraseSymbolFromSet(loopVars, var);
changeIterationVarToCountExpr(first, loop, dimensions, var);
}
}
// next level (assumes firstLoop is `dimensions` deep)
if (i != dimensions - 1)
first = first->children[0];
}
// TODO:
// assign values to the iteration variables that get replaced by other
// variables as a result of combining
// temporarily removed from the pass
/*first = firstLoop;
for (int i = 0; i < dimensions; ++i)
{
SgSymbol* loopVar = getLoopSymbol(loop);
SgSymbol* firstLoopVar = getLoopSymbol(first);
if (!isEqSymbols(loopVar, firstLoopVar))
{
SgExpression* countExpr = createIterationCountExpr(loop);
if (isSymbolInSet(firstLoopVars, loopVar))
{
SgStatement* st = new SgAssignStmt(*new SgVarRefExp(loopVar), *countExpr);
firstLoop->loop->insertStmtAfter(*st, *firstLoop->loop->controlParent());
}
else
{
SgStatement* st = new SgAssignStmt(*new SgVarRefExp(loopVar), *countExpr);
firstLoop->loop->insertStmtBefore(*st, *firstLoop->loop->controlParent());
}
}
if (i != dimensions - 1)
{
first = first->children[0];
loop = loop->children[0];
}
}*/
}
// Helper for hasDependenciesBetweenArrays. For one common array at one loop
// level, decides whether a dependency between the two nests (0 = first,
// 1 = second) must be assumed. Each vector holds, per array dimension, the
// coefficient keys of the read (readN) or write (writeN) accesses of nest N.
// An empty vector means the nest has no recorded access of that kind.
static bool hasConflictingArrayAccessRules(const vector<set<pair<int, int>>>& read0, const vector<set<pair<int, int>>>& write0,
                                           const vector<set<pair<int, int>>>& read1, const vector<set<pair<int, int>>>& write1)
{
    // number of array dimensions that carry at least one rule
    auto countMapped = [](const vector<set<pair<int, int>>>& rules) -> int
    {
        int count = 0;
        for (const auto& dimRules : rules)
            if (!dimRules.empty())
                ++count;
        return count;
    };
    // whether array dimension p carries a rule (out of range counts as "no")
    auto isMapped = [](const vector<set<pair<int, int>>>& rules, size_t p) -> bool
    {
        return p < rules.size() && !rules[p].empty();
    };

    const int countW0 = countMapped(write0), countW1 = countMapped(write1);
    const int countR0 = countMapped(read0), countR1 = countMapped(read1);

    // no writes in either nest means no dependency
    if (countW0 == 0 && countW1 == 0)
        return false;

    // Some nest uses the array but has no mapping rules for it at all. This was
    // tested inside the per-dimension loop whose length came from the first
    // nest only, so it never fired when the first nest was the empty one.
    if ((write0.empty() && read0.empty()) || (write1.empty() && read1.empty()))
        return true;

    // visit every array dimension known to any of the four access kinds
    const size_t len = std::max(std::max(write0.size(), read0.size()), std::max(write1.size(), read1.size()));
    for (size_t p = 0; p < len; ++p)
    {
        const bool w0 = isMapped(write0, p), w1 = isMapped(write1, p);
        const bool r0 = isMapped(read0, p), r1 = isMapped(read1, p);

        // A write in one nest and an access in the other follow different rules
        // on the same array dimension. Treated as a dependency for now (could be
        // refined).
        if (w0 && w1 && write0[p] != write1[p])
            return true;
        if (w0 && r1 && write0[p] != read1[p])
            return true;
        if (w1 && r0 && write1[p] != read0[p])
            return true;

        // the accesses are mapped onto different array dimensions
        if ((w0 && !w1 && countW1) || (!w0 && w1 && countW0))
            return true;
        if ((w0 && !r1 && countR1) || (!w0 && r1 && countW0))
            return true;
        // The old code guarded this pair with the sizes of write1/read1 but
        // indexed read0, which was out of bounds when the first nest had no reads.
        if ((w1 && !r0 && countR0) || (!w1 && r0 && countW1))
            return true;
    }
    return false;
}

// TODO: improve the dependency analysis for arrays
// Conservatively checks whether combining the first `dimensions` levels of the
// nests `firstLoop` and `loop` could break a dependency on an array used by
// both. For each common array and each level d (following children[0]) it
// compares, per array dimension, the coefficient keys of the read/write
// accesses recorded in readOpsForLoop / writeOpsForLoop.
// Returns true when a dependency must be assumed.
static bool hasDependenciesBetweenArrays(LoopGraph* firstLoop, LoopGraph* loop, int dimensions)
{
    set<DIST::Array*> readWriteFirst, readWriteSecond;
    vector<pair<LoopGraph*, set<DIST::Array*>*>> loops = { make_pair(firstLoop, &readWriteFirst), make_pair(loop, &readWriteSecond) };
    for (auto& nest : loops)
    {
        const LoopGraph* currLoop = nest.first;
        for (int d = 0; d < dimensions; ++d)
        {
            checkNull(currLoop, convertFileName(__FILE__).c_str(), __LINE__);
            // NOTE(review): this always copies the outermost loop's usedArraysAll;
            // the walk over levels only checks that the nest is deep enough.
            *(nest.second) = nest.first->usedArraysAll;
            currLoop = currLoop->children.size() ? currLoop->children[0] : NULL;
        }
    }

    // are there arrays that are used in both nests at all?
    set<DIST::Array*> intersect;
    std::set_intersection(readWriteFirst.begin(), readWriteFirst.end(), readWriteSecond.begin(), readWriteSecond.end(), inserter(intersect, intersect.begin()));
    if (intersect.size() == 0)
        return false;

    for (auto& array : intersect)
    {
        const LoopGraph* currLoop[2] = { firstLoop, loop };
        for (int d = 0; d < dimensions; ++d)
        {
            // per array dimension: coefficient keys of the accesses at loop level d
            vector<set<pair<int, int>>> coefsRead[2], coefsWrite[2];
            checkNull(currLoop[0], convertFileName(__FILE__).c_str(), __LINE__);
            checkNull(currLoop[1], convertFileName(__FILE__).c_str(), __LINE__);
            for (int k = 0; k < 2; ++k)
            {
                auto it = currLoop[k]->readOpsForLoop.find(array);
                if (it != currLoop[k]->readOpsForLoop.end())
                {
                    if (coefsRead[k].size() == 0)
                        coefsRead[k].resize(it->second.size());
                    for (size_t z = 0; z < it->second.size(); ++z)
                        for (auto& coef : it->second[z].coefficients)
                            coefsRead[k][z].insert(coef.first);
                }
                auto itW = currLoop[k]->writeOpsForLoop.find(array);
                if (itW != currLoop[k]->writeOpsForLoop.end())
                {
                    if (coefsWrite[k].size() == 0)
                        coefsWrite[k].resize(itW->second.size());
                    for (size_t z = 0; z < itW->second.size(); ++z)
                        for (auto& coef : itW->second[z].coefficients)
                            coefsWrite[k][z].insert(coef.first);
                }
            }

            if (hasConflictingArrayAccessRules(coefsRead[0], coefsWrite[0], coefsRead[1], coefsWrite[1]))
                return true;

            // Always advance to the next level. The old "no writes" shortcut
            // used `continue` before this point and re-examined the same loops.
            currLoop[0] = (currLoop[0]->children.size()) ? currLoop[0]->children[0] : NULL;
            currLoop[1] = (currLoop[1]->children.size()) ? currLoop[1]->children[0] : NULL;
        }
    }
    return false;
}
// Checks whether `loop` can be fused into `firstLoop` over `dimensions` perfectly nested
// levels with respect to the variables they share and, if so, reconciles their private
// variables.
//
// Returns -1 when fusion is rejected. This happens before any renaming or mapPrivates
// update is made, in either of two cases:
//   - a non-array variable appears in both loops, is in neither loop's private set and
//     (per varIsReallyNotPrivate) is not private in either loop, and it is changed in one
//     loop and read in the other;
//   - hasDependenciesBetweenArrays(firstLoop, loop, dimensions) reports a dependency.
// Otherwise returns 0 after the following steps:
//   - calling correctInheritedUsage;
//   - removing loop's iteration variables (its first `dimensions` levels) from the private
//     sets of `loop` and of all its ancestors in mapPrivates;
//   - renaming, to fresh copies created by copySymbolAndRename, the variables that are
//     private in one loop but also appear in the other loop where varIsReallyNotPrivate
//     holds. The copies are passed to makeDeclaration together with loop's statement;
//   - adding mapPrivates[loop] to the private sets of firstLoop's first `dimensions`
//     levels, and erasing the mapPrivates entries of loop's first `dimensions` levels.
// Both loops must already have an entry in mapPrivates; otherwise an internal error is reported.
// NOTE(review): with dimensions == 0 (as passed by combine on its "different bounds" path)
// no iteration variables are collected and no mapPrivates entries are merged or erased.
static int solveVarsCollisions(LoopGraph* firstLoop, LoopGraph* loop, int dimensions, map<LoopGraph*, set<SgSymbol*>>& mapPrivates)
{
    // all scalar and array references occurring anywhere in each loop
    set<SgSymbol*> firstLoopAllVars = getAllVariables<SgSymbol*>(firstLoop->loop, firstLoop->loop->lastNodeOfStmt(), set<int> { VAR_REF, ARRAY_REF });
    set<SgSymbol*> loopAllVars = getAllVariables<SgSymbol*>(loop->loop, loop->loop->lastNodeOfStmt(), set<int> { VAR_REF, ARRAY_REF });
    if (mapPrivates.find(loop) == mapPrivates.end())
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
    // local copies: edits below do not touch mapPrivates directly
    set<SgSymbol*> loopPrivates = mapPrivates[loop];
    if (mapPrivates.find(firstLoop) == mapPrivates.end())
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
    set<SgSymbol*> firstLoopPrivates = mapPrivates[firstLoop];
    set<SgSymbol*> loopIterationVars;
    fillIterationVariables(loop, loopIterationVars, dimensions);
    // reject fusion on a write/read conflict of a shared, non-private scalar
    for (SgSymbol* var : firstLoopAllVars)
    {
        if (var->type()->variant() == T_ARRAY || !varIsReallyNotPrivate(var, firstLoop, mapPrivates))
            continue;
        if (isSymbolInSet(loopPrivates, var) || isSymbolInSet(firstLoopPrivates, var))
            continue;
        bool isChangedInFirst = false, isChangedInSecond = false;
        bool isReadInFirst = false, isReadInSecond = false;
        if (isSymbolInSet(loopAllVars, var) && varIsReallyNotPrivate(var, loop, mapPrivates))
        {
            isChangedInFirst = varIsChanged(var, firstLoop);
            isChangedInSecond = varIsChanged(var, loop);
            isReadInFirst = varIsRead(var, firstLoop);
            isReadInSecond = varIsRead(var, loop);
            if (isChangedInFirst && isReadInSecond || isChangedInSecond && isReadInFirst)
                return -1;
        }
    }
    if (hasDependenciesBetweenArrays(firstLoop, loop, dimensions))
        return -1;
    correctInheritedUsage(firstLoop, loop, dimensions, firstLoopAllVars, loopAllVars);
    // a variable private in both loops must not be renamed in firstLoop
    for (SgSymbol* var : loopPrivates)
        eraseSymbolFromSet(firstLoopPrivates, var);
    // iteration variables of loop's fused levels stop being privates of loop and its ancestors
    for (SgSymbol* var : loopIterationVars)
    {
        eraseSymbolFromSet(firstLoopPrivates, var);
        eraseSymbolFromSet(loopPrivates, var);
        LoopGraph* parentLoop = loop;
        while (parentLoop)
        {
            auto pair = mapPrivates.find(parentLoop);
            if (pair == mapPrivates.end())
                printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
            eraseSymbolFromSet(pair->second, var);
            parentLoop = parentLoop->parent;
        }
    }
    vector<SgSymbol*> symbolsToDeclare;
    // privates of loop that also occur in firstLoop, and vice versa
    set<SgSymbol*> varsFromLoopToRename;
    getIntersection(firstLoopAllVars, loopPrivates, varsFromLoopToRename);
    set<SgSymbol*> varsFromFirstLoopToRename;
    getIntersection(loopAllVars, firstLoopPrivates, varsFromFirstLoopToRename);
    map<SgSymbol*, SgSymbol*> symbolsToRename;
    // rename inside loop the privates of loop that firstLoop uses as non-private
    for (SgSymbol* symbol : varsFromLoopToRename)
    {
        if (varIsReallyNotPrivate(symbol, firstLoop, mapPrivates))
        {
            SgSymbol* newSymbol = copySymbolAndRename(symbol);
            symbolsToDeclare.push_back(newSymbol);
            symbolsToRename.insert(make_pair(symbol, newSymbol));
        }
    }
    renamePrivatesInMap(loop, symbolsToRename, mapPrivates);
    renameVariablesInLoop(loop, symbolsToRename);
    symbolsToRename.clear();
    // rename inside firstLoop the privates of firstLoop that loop uses as non-private
    for (SgSymbol* symbol : varsFromFirstLoopToRename)
    {
        if (varIsReallyNotPrivate(symbol, loop, mapPrivates))
        {
            SgSymbol* newSymbol = copySymbolAndRename(symbol);
            symbolsToDeclare.push_back(newSymbol);
            symbolsToRename.insert(make_pair(symbol, newSymbol));
        }
    }
    renamePrivatesInMap(firstLoop, symbolsToRename, mapPrivates);
    renameVariablesInLoop(firstLoop, symbolsToRename);
    makeDeclaration(symbolsToDeclare, loop->loop->GetOriginal());
    // merge loop's privates into each fused level of firstLoop's nest
    // NOTE(review): every level receives mapPrivates[loop] (the outermost level of the
    // second nest), not the privates of the corresponding nested level — confirm intent.
    LoopGraph* loopToInsert = firstLoop;
    for (int i = 0; i < dimensions; ++i)
    {
        auto pair = mapPrivates.find(loopToInsert);
        if (pair == mapPrivates.end())
            printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
        for (SgSymbol* privateVar : mapPrivates[loop])
            pair->second.insert(privateVar);
        if (i != dimensions - 1)
            loopToInsert = loopToInsert->children[0];
    }
    // the fused levels of loop no longer exist as separate loops
    LoopGraph* toDelete = loop;
    for (int i = 0; i < dimensions; ++i)
    {
        mapPrivates.erase(toDelete);
        if (i != dimensions - 1)
            toDelete = toDelete->children[0];
    }
    return 0;
}
// Computes the step of the fused loop. The result is the GCD of |step1| and |step2|
// (an absent step counts as 1) and, when it is nonzero, of the distance between the two
// start values. For integer-literal starts the distance is exact; otherwise it is taken
// from the constant offsets reported by getSimpleExprVarParams.
// The result carries the sign of firstLoopStmt's step.
static int getNewStep(SgForStmt* firstLoopStmt, SgForStmt* loopStmt)
{
    SgExpression* firstStepExpr = firstLoopStmt->step();
    SgExpression* secondStepExpr = loopStmt->step();
    const int firstStep = (firstStepExpr != NULL) ? firstStepExpr->valueInteger() : 1;
    const int secondStep = (secondStepExpr != NULL) ? secondStepExpr->valueInteger() : 1;

    const int stepGcd = gcd(std::abs(firstStep), std::abs(secondStep));

    SgExpression* firstStart = firstLoopStmt->start();
    SgExpression* secondStart = loopStmt->start();

    int startDistance = 0;
    if (firstStart->isInteger() && secondStart->isInteger())
    {
        const int firstStartVal = firstStart->valueInteger();
        const int secondStartVal = secondStart->valueInteger();
        startDistance = std::abs(firstStartVal - secondStartVal);
    }
    else
    {
        // starts of the form [-]var + const: only the constant parts are compared
        bool firstMinus = false, secondMinus = false;
        int firstOffset = 0, secondOffset = 0;
        getSimpleExprVarParams(firstStart, &firstMinus, &firstOffset);
        getSimpleExprVarParams(secondStart, &secondMinus, &secondOffset);
        startDistance = std::abs(firstOffset - secondOffset);
    }

    const int magnitude = (startDistance == 0) ? stepGcd : gcd(startDistance, stepGcd);
    return (firstStep < 0) ? -magnitude : magnitude;
}
// Computes, into globalBounds, an iteration range that covers the ranges of both loops.
// The direction comes from firstLoopStmt's step (1 if absent):
//   - for a positive step, the range starts at the smaller start and ends at the larger end;
//   - for a non-positive step, it starts at the larger start and ends at the smaller end.
// Both bounds are fresh copies of the original expressions. globalBounds is always
// assigned; this function has no failure result.
// NOTE(review): compareSimpleExpressions results are read as 0 = "first < second" and
// 2 = "first > second". Any other result selects loopStmt's bound. Confirm this is the
// intended fallback when the bounds are equal or cannot be compared.
static void getGlobalBounds(SgForStmt* firstLoopStmt, SgForStmt* loopStmt, pair<SgExpression*, SgExpression*>& globalBounds)
{
    const int compStart = compareSimpleExpressions(firstLoopStmt->start(), loopStmt->start());
    const int compEnd = compareSimpleExpressions(firstLoopStmt->end(), loopStmt->end());
    const int step = (firstLoopStmt->step() != NULL) ? firstLoopStmt->step()->valueInteger() : 1;

    // take firstLoopStmt's bound only when it widens the range in the loop's direction
    const bool takeFirstStart = (step > 0) ? (compStart == 0)  // first start < second start
                                           : (compStart == 2); // first start > second start
    const bool takeFirstEnd = (step > 0) ? (compEnd == 2)      // first end > second end
                                         : (compEnd == 0);     // first end < second end

    SgExpression* start = takeFirstStart ? &firstLoopStmt->start()->copy() : &loopStmt->start()->copy();
    SgExpression* end = takeFirstEnd ? &firstLoopStmt->end()->copy() : &loopStmt->end()->copy();
    globalBounds = make_pair(start, end);
}
// Builds an IF statement that limits the fused loop (running over globalBounds with step
// newStep) to the iterations that belong to loopStmt, written in terms of loopSymbol.
// The condition is the .AND. of the following parts, in this order:
//   - loopSymbol >= start (<= for a non-positive step), if loopStmt's start differs from globalBounds.first;
//   - loopSymbol <= end   (>= for a non-positive step), if loopStmt's end differs from globalBounds.second;
//   - a MOD(...) .EQ. 0 stride check, if loopStmt's step differs from newStep.
// Returns NULL when no condition is needed. The IF is not inserted anywhere.
static SgStatement* makeIfStatementForBounds(SgForStmt* loopStmt, const pair<SgExpression*, SgExpression*>& globalBounds,
                                             SgSymbol* loopSymbol, int newStep)
{
    // each sub-condition gets its own reference to the iteration variable
    auto makeVarRef = [loopSymbol]() { return new SgExpression(VAR_REF, NULL, NULL, &loopSymbol->copy()); };

    SgExpression* stepExpr = (loopStmt->step() != NULL) ? &loopStmt->step()->copy() : new SgValueExp(1);
    const int stepVal = (loopStmt->step() != NULL) ? loopStmt->step()->valueInteger() : 1;
    const bool ascending = stepVal > 0;

    SgExpression* strideCond = NULL;
    if (stepVal != newStep)
    {
        // MOD(var - start, step) .eq. 0
        // NOTE(review): arguments are handed to makeExprList as (step, var - start);
        // presumably makeExprList reverses the order — verify against SgUtils.
        SgExpression* varRef = makeVarRef();
        SgExpression* offset = new SgExpression(SUBT_OP, varRef, &loopStmt->start()->copy());
        vector<SgExpression*> modArgs = { stepExpr, offset };
        SgExpression* modArgList = makeExprList(modArgs, false);
        SgSymbol* modName = new SgSymbol(FUNCTION_NAME, "mod");
        SgExpression* modCall = new SgExpression(FUNC_CALL, modArgList, NULL, modName);
        strideCond = new SgExpression(EQ_OP, modCall, new SgValueExp(0));
    }

    SgExpression* lowerCond = NULL;
    if (!isEqExpressions(loopStmt->start(), globalBounds.first))
    {
        SgExpression* varRef = makeVarRef();
        lowerCond = new SgExpression(ascending ? GTEQL_OP : LTEQL_OP, varRef, &loopStmt->start()->copy());
    }

    SgExpression* upperCond = NULL;
    if (!isEqExpressions(loopStmt->end(), globalBounds.second))
    {
        SgExpression* varRef = makeVarRef();
        upperCond = new SgExpression(ascending ? LTEQL_OP : GTEQL_OP, varRef, &loopStmt->end()->copy());
    }

    // conjunction that simply skips absent (NULL) parts
    auto conjoin = [](SgExpression* lhs, SgExpression* rhs) -> SgExpression* {
        if (rhs == NULL)
            return lhs;
        if (lhs == NULL)
            return rhs;
        return new SgExpression(AND_OP, lhs, rhs);
    };

    SgExpression* fullCond = conjoin(conjoin(lowerCond, upperCond), strideCond);
    if (fullCond == NULL)
        return NULL;
    return new SgIfStmt(*fullCond);
}
// Moves the body of `from` into `to`, rewriting `from`'s iteration variable as `to`'s.
// When `from`'s bounds or step differ from the fused range (globalBounds / newStep), the
// body is placed inside a guarding IF, which is inserted after to's last executable statement.
static void moveBodyWithDiffBounds(SgForStmt* from, SgForStmt* to, const pair<SgExpression*, SgExpression*>& globalBounds, int newStep)
{
    map<SgSymbol*, SgSymbol*> renaming;
    renaming[from->doName()] = to->doName();

    SgStatement* guard = makeIfStatementForBounds(from, globalBounds, to->doName(), newStep);
    if (guard == NULL)
    {
        moveBody(from, to, renaming);
        return;
    }

    to->lastExecutable()->insertStmtAfter(*guard, *to);
    moveBody(from, guard, renaming);
}
// Copies the comment text of loopFrom and its SPF_ANALYSIS_DIR attributes onto loopTo.
static void moveCommentsAndAttributes(SgStatement* loopFrom, SgStatement* loopTo)
{
    const char* fromComments = loopFrom->comments();
    if (fromComments != NULL)
        loopTo->addComment(string(fromComments).c_str());

    if (loopFrom->numberOfAttributes() == 0)
        return;

    auto analysisDirs = getAttributes<SgStatement*, SgStatement*>(loopFrom, set<int>{ SPF_ANALYSIS_DIR });
    for (auto& dir : analysisDirs)
        loopTo->addAttribute(SPF_ANALYSIS_DIR, dir, sizeof(SgStatement*));
}
// Fuses `loop` into `firstLoop` when the two DO loops differ in bounds and/or step.
// firstLoop is widened to a range covering both loops (getGlobalBounds), with the step
// from getNewStep. Each original body is wrapped in an IF from makeIfStatementForBounds
// (when one is needed), so it still runs only on its own iterations. loop's comments and
// SPF_ANALYSIS_DIR attributes are copied to firstLoop, and loop's statement is extracted.
static void combineWithDifferentBounds(const LoopGraph* firstLoop, const LoopGraph* loop)
{
    SgForStmt* firstLoopStmt = isSgForStmt(firstLoop->loop->GetOriginal());
    // validate the cast result itself (not the LoopGraph node): it is dereferenced below
    checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
    SgForStmt* loopStmt = isSgForStmt(loop->loop->GetOriginal());
    checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);

    SgExpression* step1 = firstLoopStmt->step();
    int step1Val = 1;
    if (step1 != NULL)
        step1Val = step1->valueInteger();

    pair<SgExpression*, SgExpression*> globalBounds;
    getGlobalBounds(firstLoopStmt, loopStmt, globalBounds);
    int newStep = getNewStep(firstLoopStmt, loopStmt);

    // Build firstLoop's own guard before widening its bounds: makeIfStatementForBounds
    // compares the loop's current bounds with globalBounds.
    SgStatement* firstLoopIfStmt = makeIfStatementForBounds(firstLoopStmt, globalBounds, firstLoopStmt->doName(), newStep);
    firstLoopStmt->setStart(*globalBounds.first);
    firstLoopStmt->setEnd(*globalBounds.second);
    if (newStep != step1Val)
        firstLoopStmt->setStep(*new SgValueExp(newStep));

    if (firstLoopIfStmt)
    {
        // empty map: firstLoop's own symbols are not renamed
        map<SgSymbol*, SgSymbol*> symbols;
        moveBody(firstLoopStmt, firstLoopIfStmt, symbols);
        firstLoopStmt->lastExecutable()->insertStmtAfter(*firstLoopIfStmt, *firstLoopStmt);
    }

    moveBodyWithDiffBounds(loopStmt, firstLoopStmt, globalBounds, newStep);
    moveCommentsAndAttributes(loopStmt, firstLoopStmt);
    loopStmt->extractStmt();
}
/**
 * The actual combination: fuses the loops of nextLoops, in order, into firstLoop.
 *
 * Processing stops at the first entry in any of these situations:
 *  - it is not a DO loop;
 *  - it cannot be fused;
 *  - its variable collisions cannot be solved (solveVarsCollisions returns -1);
 *  - one of firstLoop's perfectly nested levels has combination limits or is reported by
 *    hasGotoToStatement.
 *
 * For every fused loop:
 *  - its body is moved into firstLoop's nest. This is done directly when a common nest
 *    depth is found (possibly after reversing the nest chosen by getDeepestDimToReverse);
 *    otherwise combineWithDifferentBounds is used;
 *  - its statement is extracted and it is added to combinedLoops;
 *  - its inner children are re-attached under firstLoop's nest;
 *  - a note is appended to messages and countOfTransform is incremented.
 * The LoopGraph node of a fused loop is not deleted here.
 *
 * Returns true if at least one loop was fused into firstLoop during this call.
 */
static bool combine(LoopGraph* firstLoop, const vector<LoopGraph*>& nextLoops, set<LoopGraph*>& combinedLoops,
                    map<LoopGraph*, set<SgSymbol*>>& mapPrivates, vector<Messages>& messages, const map<LoopGraph*, depGraph*>& depInfoForLoopGraph,
                    int& countOfTransform)
{
    bool wasCombine = false;
    for (LoopGraph* loop : nextLoops)
    {
        if (!loop->isFor)
            return wasCombine;

        int perfectLoop = std::min(firstLoop->perfectLoop, loop->perfectLoop);
        const LoopGraph* curLoop = firstLoop;
        for (int i = 0; i < perfectLoop; ++i)
        {
            SgForStmt* loopStmt = isSgForStmt(curLoop->loop->GetOriginal());
            checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);
            // Report fusions already made in this call instead of discarding them:
            // firstLoop changes after each fusion, so this check may fail only for a later loop.
            if (curLoop->hasLimitsToCombine() || hasGotoToStatement(loopStmt))
                return wasCombine;
            if (i != perfectLoop - 1)
                curLoop = curLoop->children[0];
        }

        map<SgSymbol*, SgSymbol*> symbolsFromLoopToRename;
        int dimensionsForCombine = getDeepestDimForCombine(firstLoop, loop, perfectLoop);
        LoopGraph* loopToReverse = NULL;
        // no common depth: try getDeepestDimToReverse, which may pick a nest to reverse
        if (dimensionsForCombine == 0)
            dimensionsForCombine = getDeepestDimToReverse(firstLoop, loop, perfectLoop, &loopToReverse, depInfoForLoopGraph);

        if (dimensionsForCombine || canBeCombinedWithDiffBounds(firstLoop, loop))
        {
            // NOTE(review): on the "different bounds" path dimensionsForCombine is still 0 here
            if (solveVarsCollisions(firstLoop, loop, dimensionsForCombine, mapPrivates) == -1)
                break;

            if (dimensionsForCombine)
            {
                reverseLoop(loopToReverse, dimensionsForCombine);
                compareIterationVars(firstLoop, loop, dimensionsForCombine, symbolsFromLoopToRename);
                SgForStmt* innerMainLoop = getInnerLoop(firstLoop, dimensionsForCombine);
                moveBody(getInnerLoop(loop, dimensionsForCombine), innerMainLoop, symbolsFromLoopToRename);
                moveCommentsAndAttributes(loop->loop, firstLoop->loop);
                loop->loop->extractStmt();
            }
            else
            {
                dimensionsForCombine = 1;
                combineWithDifferentBounds(firstLoop, loop);
            }
            combinedLoops.insert(loop);
            wasCombine = true;

            // move in structure: children below the fused levels now belong to firstLoop's nest
            LoopGraph* deep = loop, *parent = firstLoop;
            for (int p = 0; p < dimensionsForCombine - 1; ++p)
            {
                deep = deep->children[0];
                parent = parent->children[0];
            }
            for (auto& toMove : deep->children)
            {
                parent->children.push_back(toMove);
                toMove->parent = parent;
            }
            deep->children.clear();
            firstLoop->recalculatePerfect();

            wstring strR, strE;
            __spf_printToLongBuf(strE, L"Loops on line %d and on line %d were combined", firstLoop->lineNum, loop->lineNum);
            __spf_printToLongBuf(strR, R100, firstLoop->lineNum, loop->lineNum);
            messages.push_back(Messages(NOTE, firstLoop->lineNum, strR, strE, 2005));
            __spf_print(1, "Loops on lines %d and %d were combined\n", firstLoop->lineNum, loop->lineNum);
            countOfTransform++;
        }
        else
            break;
    }
    return wasCombine;
}
/**
 * Returns up to loopsAmount loops that follow nextAfterThis in `loops` and are textually
 * adjacent: each returned loop begins right after the last statement of the previous one.
 * If loopsAmount < 0, returns all such following loops, up to the first statement that is
 * not one of them.
 * Returns an empty vector if nextAfterThis is not in `loops`.
 */
static vector<LoopGraph*> getNextLoops(LoopGraph* nextAfterThis, vector<LoopGraph*>& loops, int loopsAmount)
{
    vector<LoopGraph*> following;

    size_t pos = 0;
    while (pos < loops.size() && loops[pos] != nextAfterThis)
        ++pos;
    if (pos == loops.size())
        return following;

    SgStatement* prevEnd = nextAfterThis->loop->lastNodeOfStmt();
    for (size_t k = pos + 1; k < loops.size() && loopsAmount != 0; ++k, --loopsAmount)
    {
        SgStatement* candidate = loops[k]->loop->GetOriginal();
        // stop at the first gap between consecutive loops
        if (prevEnd->lexNext() != candidate)
            break;
        following.push_back(loops[k]);
        prevEnd = candidate->lastNodeOfStmt();
    }
    return following;
}
/**
 * One fusion pass over a list of sibling loops. If nothing was fused at this level, the
 * pass continues recursively over their children.
 *
 * Each DO loop is fused with the adjacent loops that follow it (see combine). Fused-away
 * loops are removed from loopGraphs, dropped from mapPrivates and deleted. After a fusion,
 * addIterationVarsToMap and recalculatePerfect are called on the loop's outermost ancestor.
 *
 * Returns true if any loop was fused during this pass: at this level, or below it when the
 * pass recursed.
 */
static bool tryToCombine(vector<LoopGraph*>& loopGraphs, map<LoopGraph*, set<SgSymbol*>>& mapPrivates,
                         vector<Messages>& messages, const map<LoopGraph*, depGraph*>& depInfoForLoopGraph,
                         int& countOfTransform)
{
    if (loopGraphs.size() == 0)
        return false;

    bool change = false;
    set<LoopGraph*> loopsToDelete;
    vector<LoopGraph*> newloopGraphs;
    // iterate over a snapshot: loopGraphs shrinks as loops are fused away
    const vector<LoopGraph*> loops = loopGraphs;
    for (size_t z = 0; z < loops.size(); ++z)
    {
        LoopGraph* loop = loops[z];
        newloopGraphs.push_back(loop);
        if (!loop->isFor)
            continue;

        vector<LoopGraph*> nextLoops = getNextLoops(loop, loopGraphs, -1);
        set<LoopGraph*> combinedLoops;
        if (nextLoops.size())
            combine(loop, nextLoops, combinedLoops, mapPrivates, messages, depInfoForLoopGraph, countOfTransform);
        // combinedLoops is the authoritative record of what was fused into `loop`
        const bool combinedHere = !combinedLoops.empty();

        for (LoopGraph* combined : combinedLoops)
        {
            loopsToDelete.insert(combined);
            loopGraphs.erase(find(loopGraphs.begin(), loopGraphs.end(), combined));
        }

        if (combinedHere)
        {
            LoopGraph* root = loop;
            while (root->parent)
                root = root->parent;
            addIterationVarsToMap(root, mapPrivates);
            root->recalculatePerfect();
            // accumulate: a later loop without fusions must not hide this one
            change = true;
        }
        // the fused loops directly follow `loop` in the snapshot; skip them
        z += combinedLoops.size();
    }

    loopGraphs = newloopGraphs;
    for (LoopGraph* elem : loopsToDelete)
    {
        // do not leave a dangling key behind in mapPrivates
        mapPrivates.erase(elem);
        delete elem;
    }

    if (change == false)
    {
        for (LoopGraph* ch : loopGraphs)
        {
            bool res = tryToCombine(ch->children, mapPrivates, messages, depInfoForLoopGraph, countOfTransform);
            change |= res;
        }
    }
    return change;
}
/**
 * Entry point of the loop-combining (fusion) pass for one file.
 *
 * If onPlace.second > 0, only the loop starting at line onPlace.second is fused with the
 * single loop immediately following it, and only when onPlace.first names this file.
 * Otherwise fusion passes are repeated over the whole loop forest until a pass fuses nothing.
 *
 * In both modes fused-away loops are removed from loopGraphs (or from their parent's
 * children) and deleted, so the loop graph stays consistent with the transformed source.
 * Always returns 0.
 */
int combineLoops(SgFile* file, vector<LoopGraph*>& loopGraphs, vector<Messages>& messages,
                 const pair<string, int>& onPlace, const map<LoopGraph*, depGraph*>& depInfoForLoopGraph,
                 int& countOfTransform)
{
    map<int, LoopGraph*> mapGraph;
    createMapLoopGraph(loopGraphs, mapGraph);

    map<LoopGraph*, set<SgSymbol*>> mapPrivates;
    fillMapPrivateVars(loopGraphs, mapPrivates);

    if (onPlace.second > 0)
    {
        if (onPlace.first != file->filename())
            return 0;

        const int onLine = onPlace.second;
        auto it = mapGraph.find(onLine);
        if (it == mapGraph.end())
            printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

        LoopGraph* target = it->second;
        vector<LoopGraph*>& siblings = target->parent ? target->parent->children : loopGraphs;
        vector<LoopGraph*> nextLoops = getNextLoops(target, siblings, 1);
        set<LoopGraph*> combinedLoops;
        if (nextLoops.size())
            combine(target, nextLoops, combinedLoops, mapPrivates, messages, depInfoForLoopGraph, countOfTransform);

        // The fused loop's statement has been extracted from the source; drop its node from
        // the loop graph as well, the same way tryToCombine does.
        for (LoopGraph* combined : combinedLoops)
        {
            auto pos = find(siblings.begin(), siblings.end(), combined);
            if (pos != siblings.end())
                siblings.erase(pos);
            delete combined;
        }
        if (!combinedLoops.empty())
        {
            LoopGraph* root = target;
            while (root->parent)
                root = root->parent;
            root->recalculatePerfect();
        }
        return 0;
    }

    // repeat fusion passes until one of them changes nothing
    bool change = true;
    while (change)
        change = tryToCombine(loopGraphs, mapPrivates, messages, depInfoForLoopGraph, countOfTransform);

    return 0;
}