diff --git a/src/Transformations/SwapOperators/swap_operators.cpp b/src/Transformations/SwapOperators/swap_operators.cpp index 162624f..b5cba8c 100644 --- a/src/Transformations/SwapOperators/swap_operators.cpp +++ b/src/Transformations/SwapOperators/swap_operators.cpp @@ -290,22 +290,88 @@ vector scheduleOperations(const map& newBody) +static bool buildNewAST(SgStatement* loop, vector& newBody) { - SgStatement* endDo = loop->lastNodeOfStmt(); - SgStatement* st = loop; - int lineNum = loop -> lineNumber() + 1; - for (int i = 0; i < newBody.size(); i++) - { - st -> setLexNext(*newBody[i]); - st = st -> lexNext(); - st -> setlineNumber(lineNum); - lineNum++; + if (!loop) {return false;} + if (newBody.empty()) {return true;} + if (loop->variant() != FOR_NODE) {return false;} + + SgStatement* loopStart = loop->lexNext(); + SgStatement* loopEnd = loop->lastNodeOfStmt(); + if (!loopStart || !loopEnd) {return false;} + + for (SgStatement* stmt : newBody) { + if (stmt && stmt != loop && stmt != loopEnd) { + SgStatement* current = loopStart; + bool found = false; + while (current && current != loopEnd->lexNext()) { + if (current == stmt) { + found = true; + break; + } + current = current->lexNext(); + } + if (!found) {return false;} + } } - st -> setLexNext(*endDo); + + vector extractedStatements; + vector savedComments; + vector savedLineNumbers; + + for (SgStatement* stmt : newBody) { + if (stmt && stmt != loop && stmt != loopEnd) { + savedComments.push_back(stmt->comments() ? strdup(stmt->comments()) : nullptr); + savedLineNumbers.push_back(stmt->lineNumber()); + SgStatement* extracted = stmt->extractStmt(); + if (extracted) {extractedStatements.push_back(extracted);} + } + } + + SgStatement* currentPos = loop; + int lineCounter = loop->lineNumber() + 1; + + for (size_t i = 0; i < extractedStatements.size(); i++) { + SgStatement* stmt = extractedStatements[i]; + if (stmt) { + if (i < savedComments.size() && savedComments[i]) { + stmt->setComments(savedComments[i]); + } + stmt->setlineNumber(lineCounter++); + currentPos->insertStmtAfter(*stmt, *loop); + currentPos = stmt; + } + } + + for (char* comment : savedComments) { + if (comment) { + free(comment); + } + } + + if (currentPos && currentPos->lexNext() != loopEnd) { + currentPos->setLexNext(*loopEnd); + } + + return true; } - +static bool validateNewOrder(SgStatement* loop, const vector& newOrder) +{ + if (!loop || newOrder.empty()) { + return true; + } + unordered_set seen; + for (SgStatement* stmt : newOrder) { + if (stmt && stmt != loop && stmt != loop->lastNodeOfStmt()) { + if (seen.count(stmt)) { + return false; + } + seen.insert(stmt); + } + } + return true; +} void runSwapOperators(SgFile *file, std::map>& loopGraph, std::map>& FullIR, int& countOfTransform) { @@ -339,7 +405,10 @@ void runSwapOperators(SgFile *file, std::map lineNumber() > firstLine) cout << v -> lineNumber() << endl; - buildNewAST(loopForAnalyze.first, new_order); + + if (validateNewOrder(loopForAnalyze.first, new_order)) { + buildNewAST(loopForAnalyze.first, new_order); + } st = loopForAnalyze.first -> lexNext(); while (st != loopForAnalyze.first -> lastNodeOfStmt()) {