#include "loops_unrolling.h"
#include "../LoopAnalyzer/loop_analyzer.h"
#include "../Utils/errors.h"
#include "../GraphLoop/graph_loops_func.h"

#include <string>
#include <vector>
#include <set>
#include <tuple>

using std::string;
using std::wstring;
using std::vector;
using std::set;

// True if `ex` is a VAR_REF whose original (pre-renaming) identifier is `symb`.
static bool isRefToSymb(SgExpression* ex, const string& symb)
{
    return ex != nullptr && ex->variant() == VAR_REF && OriginalSymbol(ex->symbol())->identifier() == symb;
}

// Recursively replaces every VAR_REF to `symb` located strictly below `ex`
// (in its lhs/rhs subtrees) by the integer constant `value`.
// The node `ex` itself is not replaced here because the pointer to it is owned
// by the caller; see replaceSymbInStmt for the top-level case.
static void replaceSymbToValue(SgExpression* ex, const string& symb, const int value)
{
    if (ex == nullptr)
        return;

    if (isRefToSymb(ex->lhs(), symb))
        ex->setLhs(new SgValueExp(value));
    if (isRefToSymb(ex->rhs(), symb))
        ex->setRhs(new SgValueExp(value));

    replaceSymbToValue(ex->lhs(), symb, value);
    replaceSymbToValue(ex->rhs(), symb, value);
}

// Substitutes `value` for `symb` in all three expression slots of `st`,
// including the case where a slot is itself a bare reference to `symb`
// (e.g. the right-hand side of `a(i) = i` or the selector of `select case (i)`).
static void replaceSymbInStmt(SgStatement* st, const string& symb, const int value)
{
    for (int k = 0; k < 3; ++k)
    {
        SgExpression* ex = st->expr(k);
        if (isRefToSymb(ex, symb))
            st->setExpression(k, *new SgValueExp(value));
        else
            replaceSymbToValue(ex, symb, value);
    }
}

// Applies replaceSymbInStmt to `first` and to every statement nested in it,
// up to and INCLUDING first->lastNodeOfStmt(). The last node must be included:
// for a logical IF or a labelled DO terminal statement it is a real statement
// that may reference the loop variable.
static void replaceSymbInSubtree(SgStatement* first, const string& symb, const int value)
{
    SgStatement* const lastNode = first->lastNodeOfStmt();
    for (SgStatement* st = first; ; st = st->lexNext())
    {
        replaceSymbInStmt(st, symb, value);
        if (st == lastNode)
            break;
    }
}

// Checks `stmt` and every statement nested in it for construct names
// (named IF / SELECT / CASE / DO / END, or EXIT / CYCLE with a name). Copying
// such statements would duplicate the names, so UNROLL rejects them.
// An ERROR message (code 2015) is appended to `messages` for each offending
// statement. Returns true if at least one was found.
// NOTE(review): unnamed EXIT / CYCLE statements that target the unrolled loop
// itself are not rejected here; after unrolling they would bind to an
// enclosing loop instead. Confirm whether such a check is needed.
static bool ifHasConstruct(SgStatement* stmt, vector<Messages>& messages)
{
    SgStatement* const lastNode = stmt->lastNodeOfStmt();
    bool retVal = false;

    for (SgStatement* st = stmt; ; st = st->lexNext())
    {
        bool hasName = false;
        switch (st->variant())
        {
        case IF_NODE:
        case SWITCH_NODE:
        case CYCLE_STMT:
        case EXIT_STMT:
        case CASE_NODE:
        case DEFAULT_NODE:
        case CONTROL_END:
            if (BIF_SYMB(st->thebif))
                hasName = true;
            break;
        case FOR_NODE:
            if (isSgForStmt(st)->constructName())
                hasName = true;
            break;
        default:
            break;
        }

        if (hasName)
        {
            wstring messageE, messageR;
            __spf_printToLongBuf(messageE, L"This operator has construct-name, so it does not allowed for UNROLL transformation");
            __spf_printToLongBuf(messageR, R195);
            messages.push_back(Messages(ERROR, st->lineNumber(), messageR, messageE, 2015));
            retVal = true;
        }

        if (st == lastNode)
            break;
    }
    return retVal;
}

// Parses the explicit "(start, end, step)" range of an UNROLL clause.
// On success it overrides the iteration parameters stored in `loop` and returns
// true. On malformed input it appends an ERROR message and returns false.
static bool readUnrollRange(LoopGraph* loop, SgExpression* rangeExp, vector<Messages>& messages)
{
    SgExprListExp* listExp = isSgExprListExp(rangeExp);
    checkNull(listExp, convertFileName(__FILE__).c_str(), __LINE__);

    if (listExp->length() != 3)
    {
        wstring messageE, messageR;
        __spf_printToLongBuf(messageE, L"wrong directive syntax");
        __spf_printToLongBuf(messageR, R185);
        messages.push_back(Messages(ERROR, loop->lineNum, messageR, messageE, 2015));
        return false;
    }

    SgExpression* start = listExp->elem(0);
    SgExpression* end = listExp->elem(1);
    SgExpression* step = listExp->elem(2);
    if (!start->isInteger() || !end->isInteger() || !step->isInteger())
    {
        wstring messageE, messageR;
        __spf_printToLongBuf(messageE, L"wrong directive syntax - expression must be evaluated");
        __spf_printToLongBuf(messageR, R186);
        messages.push_back(Messages(ERROR, loop->lineNum, messageR, messageE, 2015));
        return false;
    }

    loop->startVal = start->valueInteger();
    loop->endVal = end->valueInteger();
    loop->stepVal = step->valueInteger();
    std::tuple<int, int, int> tmp;
    loop->calculatedCountOfIters = calculateLoopIters(start, end, step, tmp);
    return true;
}

// Replaces the DO loop of `loop` by loop->calculatedCountOfIters copies of its
// body. In each copy the loop variable is replaced by that iteration's constant
// value (startVal, startVal + stepVal, ...).
// All body statements are validated BEFORE anything is changed. If any of them
// cannot be copied (construct names), the errors are reported once each and the
// program is left untouched. Returns false in that case.
static bool unrollOneLoop(LoopGraph* loop, vector<Messages>& messages)
{
    SgForStmt* currLoop = isSgForStmt(loop->loop);
    checkNull(currLoop, convertFileName(__FILE__).c_str(), __LINE__);

    SgStatement* last = currLoop->lastNodeOfStmt();
    const string loopS = OriginalSymbol(currLoop->symbol())->identifier();

    // Top-level statements of the body (nested constructs are skipped as a unit).
    vector<SgStatement*> orig;
    for (SgStatement* body = currLoop->body(); body != last; body = body->lastNodeOfStmt()->lexNext())
        orig.push_back(body);

    // Check every statement (no short-circuit) so that all problems are reported.
    bool hasConstruct = false;
    for (SgStatement* elem : orig)
        if (ifHasConstruct(elem, messages))
            hasConstruct = true;
    if (hasConstruct)
        return false;

    SgStatement* insertBefore = last->lexNext();
    SgStatement* cp = insertBefore->controlParent();
    const int unroll = loop->calculatedCountOfIters;
    for (int z = 0, val = loop->startVal; z < unroll; ++z, val += loop->stepVal)
    {
        for (SgStatement* elem : orig)
        {
            SgStatement* copyS = elem->copyPtr();
            replaceSymbInSubtree(copyS, loopS, val);
            insertBefore->insertStmtBefore(*copyS, *cp);
        }
    }

    currLoop->extractStmt();
    return true;
}

// Fully unrolls every loop in `loopGraph` (and, recursively, in its children)
// that carries an "!$SPF TRANSFORM(UNROLL[(start, end, step)])" directive.
// Inner loops are processed before their parents, so a parent's copied body
// already contains the unrolled inner code.
// `file` is not used by this function (recursive calls pass nullptr).
// Returns 0 on success and a positive value if any error was reported into
// `messages`. Processing stops at the first loop whose subtree reported an error.
int unrollLoops(SgFile* file, vector<LoopGraph*>& loopGraph, vector<Messages>& messages)
{
    int err = 0;
    for (auto& loop : loopGraph)
    {
        if (loop->children.size())
            err += unrollLoops(nullptr, loop->children, messages);
        if (err > 0)
            break;

        auto attrsTr = getAttributes<SgStatement*, SgStatement*>(loop->loop->GetOriginal(), set<int>{ SPF_TRANSFORM_DIR });
        for (auto attr : attrsTr)
        {
            SgExpression* list = attr->expr(0);
            if (list->lhs()->variant() != SPF_UNROLL_OP)
                continue;

            // Optional explicit range overrides the analysed loop bounds.
            SgExpression* range = list->lhs()->lhs();
            if (range && !readUnrollRange(loop, range, messages))
            {
                err = 1;
                break;
            }

            if (loop->calculatedCountOfIters == 0) // unknown
            {
                wstring messageE, messageR;
                __spf_printToLongBuf(messageE, L"expression must be evaluated");
                __spf_printToLongBuf(messageR, R187);
                messages.push_back(Messages(ERROR, loop->lineNum, messageR, messageE, 2015));
                err = 1;
                break;
            }

            if (!unrollOneLoop(loop, messages))
                err = 1;
            // The loop statement is gone (or rejected); no further directives apply.
            break;
        }
    }
    return err;
}