diff --git a/src/Transformations/SwapOperators/swap_operators.cpp b/src/Transformations/SwapOperators/swap_operators.cpp index b5cba8c..47565fa 100644 --- a/src/Transformations/SwapOperators/swap_operators.cpp +++ b/src/Transformations/SwapOperators/swap_operators.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "../../Utils/errors.h" #include "../../Utils/SgUtils.h" @@ -15,319 +16,206 @@ using namespace std; +string getNameByArg(SAPFOR::Argument* arg); +SgSymbol* getSybolByArg(SAPFOR::Argument* arg); -unordered_set loop_tags = {FOR_NODE/*, FORALL_NODE, WHILE_NODE, DO_WHILE_NODE*/}; -unordered_set importantDepsTags = {FOR_NODE, IF_NODE}; -unordered_set importantUpdDepsTags = {ELSEIF_NODE}; -unordered_set importantEndTags = {CONTROL_END}; - - -vector findInstructionsFromOperator(SgStatement* st, vector Blocks) -{ +static vector findInstructionsFromOperator(SgStatement* st, vector Blocks) { vector result; - string filename = st -> fileName(); - for (auto& block: Blocks) - { - vector instructionsInBlock = block -> getInstructions(); - for (auto& instruction: instructionsInBlock) - { - SgStatement* curOperator = instruction -> getInstruction() -> getOperator(); - if (curOperator -> lineNumber() == st -> lineNumber()) + string filename = st->fileName(); + + for (auto& block: Blocks) { + vector instructionsInBlock = block->getInstructions(); + for (auto& instruction: instructionsInBlock) { + SgStatement* curOperator = instruction->getInstruction()->getOperator(); + // Match by line number to find corresponding IR instruction + if (curOperator->lineNumber() == st->lineNumber()) { result.push_back(instruction); + } } } return result; } -vector findFuncBlocksByFuncStatement(SgStatement *st, map>& FullIR) -{ - vector result; - Statement* forSt = (Statement*)st; - for (auto& func: FullIR) - { - if (func.first -> funcPointer -> getCurrProcessFile() == forSt -> getCurrProcessFile() - && func.first -> funcPointer -> lineNumber() == forSt -> lineNumber()) - result = func.second; - } - return result; -} +unordered_set loop_tags = {FOR_NODE}; +unordered_set control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE}; +unordered_set control_end_tags = {CONTROL_END}; -map> findAndAnalyzeLoops(SgStatement *st, vector blocks) -{ - map> result; - SgStatement *lastNode = st->lastNodeOfStmt(); - while (st && st != lastNode) - { - if (loop_tags.find(st -> variant()) != loop_tags.end()) - { - // part with find statements of loop - SgForStmt *forSt = (SgForStmt*)st; - SgStatement *loopBody = forSt -> body(); - SgStatement *lastLoopNode = st->lastNodeOfStmt(); - // part with find blocks and instructions of loops - unordered_set blocks_nums; - while (loopBody && loopBody != lastLoopNode) - { - SAPFOR::IR_Block* IR = findInstructionsFromOperator(loopBody, blocks).front(); - if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) - { - result[forSt].push_back(IR -> getBasicBlock()); - blocks_nums.insert(IR -> getBasicBlock() -> getNumber()); - } - loopBody = loopBody -> lexNext(); - } - std::sort(result[forSt].begin(), result[forSt].end()); - } - st = st -> lexNext(); - } - return result; -} - -map> AnalyzeLoopAndFindDeps(SgForStmt* forStatement, vector loopBlocks, map>& FullIR) -{ - map> result; - for (SAPFOR::BasicBlock* bb: loopBlocks) - { - map> blockReachingDefinitions = bb -> getRD_In(); - vector instructions = bb -> getInstructions(); - for (SAPFOR::IR_Block* irBlock: instructions) - { - // TODO: Think about what to do with function calls and array references. Because there are also dependencies there that are not reflected in RD, but they must be taken into account - SAPFOR::Instruction* instr = irBlock -> getInstruction(); - result[instr -> getOperator()]; - // take Argument 1 and it's RD and push operators to final set - if (instr -> getArg1() != NULL) - { - SAPFOR::Argument* arg = instr -> getArg1(); - set prevInstructionsNumbers = blockReachingDefinitions[arg]; - for (int i: prevInstructionsNumbers) - { - SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first; - if (foundInstruction != NULL) - { - SgStatement* prevOp = foundInstruction -> getOperator(); - if (prevOp != forStatement && instr -> getOperator() != forStatement && instr -> getOperator() -> lineNumber() > prevOp -> lineNumber() - && prevOp -> lineNumber() > forStatement -> lineNumber()) - result[instr -> getOperator()].insert(prevOp); - } - } - } - // take Argument 2 (if exists) and it's RD and push operators to final set - if (instr -> getArg2() != NULL) - { - SAPFOR::Argument* arg = instr -> getArg2(); - set prevInstructionsNumbers = blockReachingDefinitions[arg]; - for (int i: prevInstructionsNumbers) - { - SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first; - if (foundInstruction != NULL) - { - SgStatement* prevOp = foundInstruction -> getOperator(); - if (prevOp != forStatement && instr -> getOperator() != forStatement&& instr -> getOperator() -> lineNumber() > prevOp -> lineNumber() - && prevOp -> lineNumber() > forStatement -> lineNumber()) - result[instr -> getOperator()].insert(prevOp); - } - } - } - // update RD - if (instr -> getResult() != NULL) - blockReachingDefinitions[instr -> getResult()] = {instr -> getNumber()}; - } - } - return result; -} - -void buildAdditionalDeps(SgForStmt* forStatement, map>& dependencies) -{ - SgStatement* lastNode = forStatement->lastNodeOfStmt(); - vector importantDeps; - SgStatement* st = (SgStatement*) forStatement; - st = st -> lexNext(); - SgStatement* logIfOp = NULL; - while (st && st != lastNode) - { - if(importantDeps.size() != 0) - { - if (st != importantDeps.back()) - { - dependencies[st].insert(importantDeps.back()); - } - } - if (logIfOp != NULL) - { - dependencies[st].insert(logIfOp); - logIfOp = NULL; - } - if (st -> variant() == LOGIF_NODE) - { - logIfOp = st; - } - if (importantDepsTags.find(st -> variant()) != importantDepsTags.end()) - { - importantDeps.push_back(st); - } - if (importantUpdDepsTags.find(st -> variant()) != importantUpdDepsTags.end()) - { - importantDeps.pop_back(); - importantDeps.push_back(st); - } - if (importantEndTags.find(st -> variant()) != importantEndTags.end()) - { - if(importantDeps.size() != 0) - { - importantDeps.pop_back(); - } - } - st = st -> lexNext(); - } -} - -struct ReadyOp { +struct OperatorInfo { SgStatement* stmt; - int degree; - size_t arrival; - ReadyOp(SgStatement* s, int d, size_t a): stmt(s), degree(d), arrival(a) {} + set usedVars; + set definedVars; + int lineNumber; + bool isMovable; + OperatorInfo(SgStatement* s) : stmt(s), lineNumber(s->lineNumber()), isMovable(true) {} }; -struct ReadyOpCompare { - bool operator()(const ReadyOp& a, const ReadyOp& b) const { - if (a.degree != b.degree) - return a.degree > b.degree; - else - return a.arrival > b.arrival; - } -}; - -vector scheduleOperations(const map>& dependencies) -{ - // get all statements - unordered_set allStmtsSet; - for (const auto& pair : dependencies) - { - allStmtsSet.insert(pair.first); - for (SgStatement* dep : pair.second) - { - allStmtsSet.insert(dep); - } - } - vector allStmts(allStmtsSet.begin(), allStmtsSet.end()); - // count deps and build reversed graph - unordered_map> graph; - unordered_map inDegree; - unordered_map degree; - for (auto op : allStmts) - inDegree[op] = 0; - // find and remember initial dependencies - unordered_set dependentStmts; - for (const auto& pair : dependencies) - { - SgStatement* op = pair.first; - const auto& deps = pair.second; - degree[op] = deps.size(); - inDegree[op] = deps.size(); - if (!deps.empty()) - dependentStmts.insert(op); - for (auto dep : deps) - graph[dep].push_back(op); - } - for (SgStatement* op : allStmts) - { - if (!degree.count(op)) - { - degree[op] = 0; - } - } - // build queues - using PQ = priority_queue, ReadyOpCompare>; - PQ readyDependent; - queue readyIndependent; - size_t arrivalCounter = 0; - for (auto op : allStmts) - { - if (inDegree[op] == 0) - { - if (dependentStmts.count(op)) - { - readyDependent.emplace(op, degree[op], arrivalCounter++); - } - else - { - readyIndependent.push(op); - } - } - } - // main sort algorythm - vector executionOrder; - while (!readyDependent.empty() || !readyIndependent.empty()) - { - SgStatement* current = nullptr; - if (!readyDependent.empty()) - { - current = readyDependent.top().stmt; - readyDependent.pop(); - } - else - { - current = readyIndependent.front(); - readyIndependent.pop(); - } - executionOrder.push_back(current); - for (SgStatement* neighbor : graph[current]) - { - inDegree[neighbor]--; - if (inDegree[neighbor] == 0) { - if (dependentStmts.count(neighbor)) - { - readyDependent.emplace(neighbor, degree[neighbor], arrivalCounter++); - } - else - { - readyIndependent.push(neighbor); - } - } - } - } - return executionOrder; -} - -static bool buildNewAST(SgStatement* loop, vector& newBody) -{ - if (!loop) {return false;} - if (newBody.empty()) {return true;} - if (loop->variant() != FOR_NODE) {return false;} - +static vector analyzeOperatorsInLoop(SgForStmt* loop, vector blocks, map>& FullIR) { + vector operators; SgStatement* loopStart = loop->lexNext(); SgStatement* loopEnd = loop->lastNodeOfStmt(); - if (!loopStart || !loopEnd) {return false;} + + SgStatement* current = loopStart; + while (current && current != loopEnd) { + if (isSgExecutableStatement(current)) { + OperatorInfo opInfo(current); + + vector irBlocks = findInstructionsFromOperator(current, blocks); + for (auto irBlock : irBlocks) { + SAPFOR::Instruction* instr = irBlock->getInstruction(); + + if (instr->getArg1()) { + string varName = getNameByArg(instr->getArg1()); + if (!varName.empty()) { + opInfo.usedVars.insert(varName); + } + } + if (instr->getArg2()) { + string varName = getNameByArg(instr->getArg2()); + if (!varName.empty()) { + opInfo.usedVars.insert(varName); + } + } + if (instr->getResult()) { + string varName = getNameByArg(instr->getResult()); + if (!varName.empty()) { + opInfo.definedVars.insert(varName); + } + } + } + + if (control_tags.find(current->variant()) != control_tags.end()) { + opInfo.isMovable = false; + } + + operators.push_back(opInfo); + } + current = current->lexNext(); + } + + return operators; +} - for (SgStatement* stmt : newBody) { - if (stmt && stmt != loop && stmt != loopEnd) { - SgStatement* current = loopStart; - bool found = false; - while (current && current != loopEnd->lexNext()) { - if (current == stmt) { - found = true; +static map> findVariableDefinitions(SgForStmt* loop, vector& operators) { + map> varDefinitions; + + for (auto& op : operators) { + for (const string& var : op.definedVars) { + varDefinitions[var].push_back(op.stmt); + } + } + + return varDefinitions; +} + +static int calculateDistance(SgStatement* from, SgStatement* to) { + if (!from || !to) return INT_MAX; + return abs(to->lineNumber() - from->lineNumber()); +} + +static SgStatement* findBestPosition(SgStatement* operatorStmt, vector& operators, map>& varDefinitions) { + OperatorInfo* opInfo = nullptr; + for (auto& op : operators) { + if (op.stmt == operatorStmt) { + opInfo = &op; + break; + } + } + + if (!opInfo || !opInfo->isMovable) return nullptr; + + SgStatement* bestPos = nullptr; + int minDistance = INT_MAX; + + for (const string& usedVar : opInfo->usedVars) { + if (varDefinitions.find(usedVar) != varDefinitions.end()) { + for (SgStatement* defStmt : varDefinitions[usedVar]) { + int distance = calculateDistance(operatorStmt, defStmt); + if (distance < minDistance) { + minDistance = distance; + bestPos = defStmt; + } + } + } + } + + return bestPos; +} + +static bool canMoveTo(SgStatement* from, SgStatement* to, SgForStmt* loop) { + if (!from || !to || from == to) return false; + + SgStatement* loopStart = loop->lexNext(); + SgStatement* loopEnd = loop->lastNodeOfStmt(); + + if (to->lineNumber() < loopStart->lineNumber() || to->lineNumber() > loopEnd->lineNumber()) { + return false; + } + + SgStatement* current = from; + while (current && current != loopEnd) { + if (control_tags.find(current->variant()) != control_tags.end()) { + return false; + } + if (current == to) break; + current = current->lexNext(); + } + + return true; +} + +static vector optimizeOperatorOrder(SgForStmt* loop, vector& operators, map>& varDefinitions) { + vector newOrder; + vector moved(operators.size(), false); + + for (size_t i = 0; i < operators.size(); i++) { + if (moved[i] || !operators[i].isMovable) { + newOrder.push_back(operators[i].stmt); + moved[i] = true; + continue; + } + + SgStatement* bestPos = findBestPosition(operators[i].stmt, operators, varDefinitions); + + if (bestPos && canMoveTo(operators[i].stmt, bestPos, loop)) { + bool inserted = false; + for (size_t j = 0; j < newOrder.size(); j++) { + if (newOrder[j] == bestPos) { + newOrder.insert(newOrder.begin() + j + 1, operators[i].stmt); + inserted = true; break; } - current = current->lexNext(); } - if (!found) {return false;} + if (!inserted) { + newOrder.push_back(operators[i].stmt); + } + } else { + newOrder.push_back(operators[i].stmt); } + moved[i] = true; } + + return newOrder; +} +static bool applyOperatorReordering(SgForStmt* loop, vector& newOrder) { + if (!loop || newOrder.empty()) return false; + + SgStatement* loopStart = loop->lexNext(); + SgStatement* loopEnd = loop->lastNodeOfStmt(); + vector extractedStatements; vector savedComments; - vector savedLineNumbers; - - for (SgStatement* stmt : newBody) { + + for (SgStatement* stmt : newOrder) { if (stmt && stmt != loop && stmt != loopEnd) { savedComments.push_back(stmt->comments() ? strdup(stmt->comments()) : nullptr); - savedLineNumbers.push_back(stmt->lineNumber()); SgStatement* extracted = stmt->extractStmt(); - if (extracted) {extractedStatements.push_back(extracted);} + if (extracted) { + extractedStatements.push_back(extracted); + } } } - + SgStatement* currentPos = loop; int lineCounter = loop->lineNumber() + 1; @@ -342,13 +230,13 @@ static bool buildNewAST(SgStatement* loop, vector& newBody) currentPos = stmt; } } - + for (char* comment : savedComments) { if (comment) { free(comment); } } - + if (currentPos && currentPos->lexNext() != loopEnd) { currentPos->setLexNext(*loopEnd); } @@ -356,67 +244,56 @@ static bool buildNewAST(SgStatement* loop, vector& newBody) return true; } -static bool validateNewOrder(SgStatement* loop, const vector& newOrder) -{ - if (!loop || newOrder.empty()) { - return true; +vector findFuncBlocksByFuncStatement(SgStatement *st, map>& FullIR) { + vector result; + Statement* forSt = (Statement*)st; + for (auto& func: FullIR) { + if (func.first -> funcPointer -> getCurrProcessFile() == forSt -> getCurrProcessFile() + && func.first -> funcPointer -> lineNumber() == forSt -> lineNumber()) + result = func.second; } - unordered_set seen; - for (SgStatement* stmt : newOrder) { - if (stmt && stmt != loop && stmt != loop->lastNodeOfStmt()) { - if (seen.count(stmt)) { - return false; - } - seen.insert(stmt); - } - } - return true; + return result; } -void runSwapOperators(SgFile *file, std::map>& loopGraph, std::map>& FullIR, int& countOfTransform) -{ - std::cout << "SWAP_OPERATORS Pass" << std::endl; // to remove - countOfTransform += 1; // to remove +map> findAndAnalyzeLoops(SgStatement *st, vector blocks) { + map> result; + SgStatement *lastNode = st->lastNodeOfStmt(); + while (st && st != lastNode) { + if (loop_tags.find(st -> variant()) != loop_tags.end()) { + SgForStmt *forSt = (SgForStmt*)st; + SgStatement *loopBody = forSt -> body(); + SgStatement *lastLoopNode = st->lastNodeOfStmt(); + unordered_set blocks_nums; + while (loopBody && loopBody != lastLoopNode) { + SAPFOR::IR_Block* IR = findInstructionsFromOperator(loopBody, blocks).front(); + if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) { + result[forSt].push_back(IR -> getBasicBlock()); + blocks_nums.insert(IR -> getBasicBlock() -> getNumber()); + } + loopBody = loopBody -> lexNext(); + } + std::sort(result[forSt].begin(), result[forSt].end()); + } + st = st -> lexNext(); + } + return result; +} + +void runSwapOperators(SgFile *file, std::map>& loopGraph, std::map>& FullIR, int& countOfTransform) { + countOfTransform += 1; const int funcNum = file -> numberOfFunctions(); - for (int i = 0; i < funcNum; ++i) - { + for (int i = 0; i < funcNum; ++i) { SgStatement *st = file -> functions(i); vector blocks = findFuncBlocksByFuncStatement(st, FullIR); map> loopsMapping = findAndAnalyzeLoops(st, blocks); - for (pair> loopForAnalyze: loopsMapping) - { - map> dependencyGraph = AnalyzeLoopAndFindDeps(loopForAnalyze.first, loopForAnalyze.second, FullIR); - // TODO: Write a function that will go through the operators and update all dependencies so that there are no mix-ups and splits inside the semantic blocks (for if, do and may be some other cases) - buildAdditionalDeps(loopForAnalyze.first, dependencyGraph); - cout << endl; - int firstLine = loopForAnalyze.first -> lineNumber(); - int lastLine = loopForAnalyze.first -> lastNodeOfStmt() -> lineNumber(); - cout << "LOOP ANALYZE FROM " << firstLine << " TO " << lastLine << " RES" << endl; - // for (auto &v: dependencyGraph) { - // cout << "OPERATOR: " << v.first -> lineNumber() << " " << v.first -> variant() << "\nDEPENDS ON:" << endl; - // if (v.second.size() != 0) - // for (auto vv: v.second) - // cout << vv -> lineNumber() << " "; - // cout << endl; - // } - vector new_order = scheduleOperations(dependencyGraph); - cout << "RESULT ORDER:" << endl; - for (auto v: new_order) - if (v -> lineNumber() > firstLine) - cout << v -> lineNumber() << endl; - - if (validateNewOrder(loopForAnalyze.first, new_order)) { - buildNewAST(loopForAnalyze.first, new_order); - } - st = loopForAnalyze.first -> lexNext(); - while (st != loopForAnalyze.first -> lastNodeOfStmt()) - { - cout << st -> lineNumber() << " " << st -> sunparse() << endl; - st = st -> lexNext(); - } + + for (pair> loopForAnalyze: loopsMapping) { + vector operators = analyzeOperatorsInLoop(loopForAnalyze.first, loopForAnalyze.second, FullIR); + map> varDefinitions = findVariableDefinitions(loopForAnalyze.first, operators); + + vector newOrder = optimizeOperatorOrder(loopForAnalyze.first, operators, varDefinitions); + applyOperatorReordering(loopForAnalyze.first, newOrder); } } - - return; -}; +} \ No newline at end of file