213 lines
7.4 KiB
C++
213 lines
7.4 KiB
C++
|
|
#include "loops_unrolling.h"
|
||
|
|
|
||
|
|
#include "../LoopAnalyzer/loop_analyzer.h"
|
||
|
|
#include "../Utils/errors.h"
|
||
|
|
#include "../GraphLoop/graph_loops_func.h"
|
||
|
|
|
||
|
|
#include <string>
|
||
|
|
#include <vector>
|
||
|
|
#include <set>
|
||
|
|
|
||
|
|
using std::string;
|
||
|
|
using std::wstring;
|
||
|
|
using std::vector;
|
||
|
|
using std::set;
|
||
|
|
|
||
|
|
// Recursively walks the expression tree rooted at `ex` and replaces every
// descendant node that is a plain variable reference (VAR_REF) to `symb`
// with the integer constant `value`.
// Note: only children of a node are ever replaced (via setLhs/setRhs); the
// root `ex` itself is never replaced, so a caller whose whole expression is
// the reference must handle that case itself. A NULL `ex` is a no-op.
static void replaceSymbToValue(SgExpression* ex, const string& symb, const int value)
{
    if (!ex)
        return;

    // true when `node` is a reference to the variable being substituted
    auto refersToSymb = [&symb](SgExpression* node)
    {
        return node->variant() == VAR_REF && OriginalSymbol(node->symbol())->identifier() == symb;
    };

    SgExpression* leftChild = ex->lhs();
    if (leftChild && refersToSymb(leftChild))
        ex->setLhs(new SgValueExp(value));

    SgExpression* rightChild = ex->rhs();
    if (rightChild && refersToSymb(rightChild))
        ex->setRhs(new SgValueExp(value));

    // descend into both subtrees (re-read after a possible replacement)
    replaceSymbToValue(ex->lhs(), symb, value);
    replaceSymbToValue(ex->rhs(), symb, value);
}
|
||
|
|
|
||
|
|
// Scans `copy` together with every statement nested in it (from `copy` up to
// and including copy->lastNodeOfStmt()) for statements that carry a construct
// name, which the UNROLL transformation does not allow.
// For each offending statement an ERROR message (R195, code 2015) is appended
// to `messages`.
// Returns true if at least one such statement was found.
static bool ifHasConstruct(SgStatement* copy, vector<Messages>& messages)
{
    // For these statement kinds a non-null BIF_SYMB is treated as a construct name.
    static const set<int> namedByBifSymbol = { IF_NODE, SWITCH_NODE, CYCLE_STMT, EXIT_STMT,
                                               CASE_NODE, DEFAULT_NODE, CONTROL_END };

    auto reportNamed = [&messages](SgStatement* where)
    {
        wstring messageE, messageR;
        __spf_printToLongBuf(messageE, L"This operator has construct-name, so it does not allowed for UNROLL transformation");
        __spf_printToLongBuf(messageR, R195);

        messages.push_back(Messages(ERROR, where->lineNumber(), messageR, messageE, 2015));
    };

    bool found = false;
    // sentinel: first statement after the whole construct
    SgStatement* const stop = copy->lastNodeOfStmt()->lexNext();

    SgStatement* cur = copy;
    do
    {
        if (namedByBifSymbol.count(cur->variant()) && BIF_SYMB(cur->thebif))
        {
            reportNamed(cur);
            found = true;
        }

        if (cur->variant() == FOR_NODE && isSgForStmt(cur)->constructName())
        {
            reportNamed(cur);
            found = true;
        }

        cur = cur->lexNext();
    } while (cur != stop);

    return found;
}
|
||
|
|
|
||
|
|
int unrollLoops(SgFile* file, vector<LoopGraph*>& loopGraph, vector<Messages>& messages)
|
||
|
|
{
|
||
|
|
int err = 0;
|
||
|
|
|
||
|
|
for (auto& loop : loopGraph)
|
||
|
|
{
|
||
|
|
if (loop->children.size())
|
||
|
|
err += unrollLoops(NULL, loop->children, messages);
|
||
|
|
|
||
|
|
if (err > 0)
|
||
|
|
break;
|
||
|
|
|
||
|
|
auto attrsTr = getAttributes<SgStatement*, SgStatement*>(loop->loop->GetOriginal(), set<int>{ SPF_TRANSFORM_DIR });
|
||
|
|
for (auto attr : attrsTr)
|
||
|
|
{
|
||
|
|
SgExpression* list = attr->expr(0);
|
||
|
|
if (list->lhs()->variant() == SPF_UNROLL_OP)
|
||
|
|
{
|
||
|
|
int range[3] = { 0, 0, 0 };
|
||
|
|
bool inited = false;
|
||
|
|
if (list->lhs()->lhs())// if with range
|
||
|
|
{
|
||
|
|
SgExprListExp* listExp = isSgExprListExp(list->lhs()->lhs());
|
||
|
|
checkNull(listExp, convertFileName(__FILE__).c_str(), __LINE__);
|
||
|
|
|
||
|
|
int len = listExp->length();
|
||
|
|
if (len != 3)
|
||
|
|
{
|
||
|
|
wstring messageE, messageR;
|
||
|
|
__spf_printToLongBuf(messageE, L"wrong directive syntax");
|
||
|
|
__spf_printToLongBuf(messageR, R185);
|
||
|
|
|
||
|
|
messages.push_back(Messages(ERROR, loop->lineNum, messageR, messageE, 2015));
|
||
|
|
|
||
|
|
err = 1;
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
inited = true;
|
||
|
|
|
||
|
|
bool isFalse = !listExp->elem(0)->isInteger() || !listExp->elem(1)->isInteger() || !listExp->elem(2)->isInteger();
|
||
|
|
|
||
|
|
if (!isFalse)
|
||
|
|
{
|
||
|
|
loop->startVal = listExp->elem(0)->valueInteger();
|
||
|
|
loop->endVal = listExp->elem(1)->valueInteger();
|
||
|
|
loop->stepVal = listExp->elem(2)->valueInteger();
|
||
|
|
|
||
|
|
std::tuple<int, int, int> tmp;
|
||
|
|
loop->calculatedCountOfIters = calculateLoopIters(listExp->elem(0), listExp->elem(1), listExp->elem(2), tmp);
|
||
|
|
}
|
||
|
|
else
|
||
|
|
{
|
||
|
|
wstring messageE, messageR;
|
||
|
|
__spf_printToLongBuf(messageE, L"wrong directive syntax - expression must be evaluated");
|
||
|
|
__spf_printToLongBuf(messageR, R186);
|
||
|
|
|
||
|
|
messages.push_back(Messages(ERROR, loop->lineNum, messageR, messageE, 2015));
|
||
|
|
err = 1;
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if (loop->calculatedCountOfIters == 0) // unknown
|
||
|
|
{
|
||
|
|
wstring messageE, messageR;
|
||
|
|
__spf_printToLongBuf(messageE, L"expression must be evaluated");
|
||
|
|
__spf_printToLongBuf(messageR, R187);
|
||
|
|
|
||
|
|
messages.push_back(Messages(ERROR, loop->lineNum, messageR, messageE, 2015));
|
||
|
|
err = 1;
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
|
||
|
|
int unroll = loop->calculatedCountOfIters;
|
||
|
|
|
||
|
|
SgForStmt* currLoop = (SgForStmt*)loop->loop;
|
||
|
|
|
||
|
|
SgStatement* last = currLoop->lastNodeOfStmt();
|
||
|
|
const string loopS = OriginalSymbol(currLoop->symbol())->identifier();
|
||
|
|
vector<SgStatement*> orig;
|
||
|
|
|
||
|
|
SgStatement* body = currLoop->body();
|
||
|
|
while (body != last)
|
||
|
|
{
|
||
|
|
orig.push_back(body);
|
||
|
|
|
||
|
|
body = body->lastNodeOfStmt();
|
||
|
|
body = body->lexNext();
|
||
|
|
}
|
||
|
|
|
||
|
|
SgStatement* insertBefore = last->lexNext();
|
||
|
|
SgStatement* cp = insertBefore->controlParent();
|
||
|
|
for (int z = 0, val = loop->startVal; z < unroll; ++z, val += loop->stepVal)
|
||
|
|
{
|
||
|
|
for (auto& elem : orig)
|
||
|
|
{
|
||
|
|
SgStatement* copyS = elem->copyPtr();
|
||
|
|
if (ifHasConstruct(elem, messages))
|
||
|
|
err = 1;
|
||
|
|
|
||
|
|
SgStatement* lastS = copyS->lastNodeOfStmt();
|
||
|
|
if (copyS == lastS)
|
||
|
|
lastS = lastS->lexNext();
|
||
|
|
|
||
|
|
SgStatement* st = copyS;
|
||
|
|
do
|
||
|
|
{
|
||
|
|
for (int k = 0; k < 3; ++k)
|
||
|
|
replaceSymbToValue(st->expr(k), loopS, val);
|
||
|
|
|
||
|
|
st = st->lexNext();
|
||
|
|
} while (st != lastS);
|
||
|
|
|
||
|
|
insertBefore->insertStmtBefore(*copyS, *cp);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
currLoop->extractStmt();
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
return err;
|
||
|
|
}
|