Adding handing of nested loops and conditional statements

This commit is contained in:
Egor Mayorov
2025-12-11 01:41:58 +03:00
parent e9d5a2ee70
commit a6e6af7577

View File

@@ -17,7 +17,6 @@
using namespace std;
string getNameByArg(SAPFOR::Argument* arg);
SgSymbol* getSybolByArg(SAPFOR::Argument* arg);
static vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, vector<SAPFOR::BasicBlock*> Blocks) {
vector<SAPFOR::IR_Block*> result;
@@ -36,9 +35,9 @@ static vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, v
return result;
}
unordered_set<int> loop_tags = {FOR_NODE};
unordered_set<int> control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE};
unordered_set<int> control_end_tags = {CONTROL_END};
unordered_set<int> loop_tags = {FOR_NODE}; // Loop statements
unordered_set<int> control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE, LOGIF_NODE}; // Control structures that cannot be moved
unordered_set<int> control_end_tags = {CONTROL_END}; // End marker
struct OperatorInfo {
SgStatement* stmt;
@@ -46,21 +45,170 @@ struct OperatorInfo {
set<string> definedVars;
int lineNumber;
bool isMovable;
OperatorInfo(SgStatement* s) : stmt(s), lineNumber(s->lineNumber()), isMovable(true) {}
};
static bool isStatementEmbedded(SgStatement* stmt, SgStatement* parent) {
if (!stmt || !parent || stmt == parent) return false;
if (parent->variant() == LOGIF_NODE) {
if (stmt->lineNumber() == parent->lineNumber()) {
return true;
}
SgStatement* current = parent;
SgStatement* lastNode = parent->lastNodeOfStmt();
while (current && current != lastNode) {
if (current == stmt) {
return true;
}
if (current->isIncludedInStmt(*stmt)) {
return true;
}
current = current->lexNext();
}
}
if (parent->isIncludedInStmt(*stmt)) {
return true;
}
return false;
}
static bool isLoopBoundary(SgStatement* stmt) {
if (!stmt) return false;
if (stmt->variant() == FOR_NODE || stmt->variant() == CONTROL_END) {
return true;
}
return false;
}
static bool isPartOfNestedLoop(SgStatement* stmt, SgForStmt* loop) {
if (!stmt || !loop) return false;
SgStatement* loopStart = loop->lexNext();
SgStatement* loopEnd = loop->lastNodeOfStmt();
if (!loopStart || !loopEnd) return false;
if (stmt->lineNumber() < loopStart->lineNumber() || stmt->lineNumber() > loopEnd->lineNumber()) {
return false;
}
SgStatement* current = loopStart;
while (current && current != loopEnd) {
if (current->variant() == FOR_NODE && current != loop) {
SgForStmt* nestedLoop = (SgForStmt*)current;
SgStatement* nestedStart = nestedLoop->lexNext();
SgStatement* nestedEnd = nestedLoop->lastNodeOfStmt();
if (nestedStart && nestedEnd &&
stmt->lineNumber() >= nestedStart->lineNumber() &&
stmt->lineNumber() <= nestedEnd->lineNumber()) {
return true;
}
}
current = current->lexNext();
}
return false;
}
static bool canSafelyExtract(SgStatement* stmt, SgForStmt* loop) {
if (!stmt || !loop) return false;
if (isLoopBoundary(stmt)) {
return false;
}
if (control_tags.find(stmt->variant()) != control_tags.end()) {
return false;
}
if (isPartOfNestedLoop(stmt, loop)) {
return false;
}
SgStatement* loopStart = loop->lexNext();
SgStatement* loopEnd = loop->lastNodeOfStmt();
if (!loopStart || !loopEnd) return false;
SgStatement* current = loopStart;
while (current && current != loopEnd) {
if (current->variant() == LOGIF_NODE && current->lineNumber() == stmt->lineNumber()) {
return false;
}
if (control_tags.find(current->variant()) != control_tags.end()) {
if (isStatementEmbedded(stmt, current)) {
return false;
}
}
if (current == stmt) break;
current = current->lexNext();
}
return true;
}
static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFOR::BasicBlock*> blocks, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR) {
vector<OperatorInfo> operators;
SgStatement* loopStart = loop->lexNext();
SgStatement* loopEnd = loop->lastNodeOfStmt();
if (!loopStart || !loopEnd) {
return operators;
}
SgStatement* current = loopStart;
unordered_set<SgStatement*> visited;
while (current && current != loopEnd) {
if (visited.find(current) != visited.end()) {
break;
}
visited.insert(current);
if (isLoopBoundary(current)) {
current = current->lexNext();
continue;
}
if (current->variant() == FOR_NODE && current != loop) {
SgStatement* nestedEnd = current->lastNodeOfStmt();
if (nestedEnd) {
current = nestedEnd->lexNext();
} else {
current = current->lexNext();
}
continue;
}
if (isSgExecutableStatement(current)) {
if (control_tags.find(current->variant()) != control_tags.end()) {
current = current->lexNext();
continue;
}
if (current->variant() != ASSIGN_STAT) {
current = current->lexNext();
continue;
}
OperatorInfo opInfo(current);
vector<SAPFOR::IR_Block*> irBlocks = findInstructionsFromOperator(current, blocks);
for (auto irBlock : irBlocks) {
if (!irBlock || !irBlock->getInstruction()) continue;
SAPFOR::Instruction* instr = irBlock->getInstruction();
if (instr->getArg1()) {
@@ -83,10 +231,6 @@ static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFO
}
}
if (control_tags.find(current->variant()) != control_tags.end()) {
opInfo.isMovable = false;
}
operators.push_back(opInfo);
}
current = current->lexNext();
@@ -97,13 +241,11 @@ static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFO
static map<string, vector<SgStatement*>> findVariableDefinitions(SgForStmt* loop, vector<OperatorInfo>& operators) {
map<string, vector<SgStatement*>> varDefinitions;
for (auto& op : operators) {
for (const string& var : op.definedVars) {
varDefinitions[var].push_back(op.stmt);
}
}
return varDefinitions;
}
@@ -121,7 +263,9 @@ static SgStatement* findBestPosition(SgStatement* operatorStmt, vector<OperatorI
}
}
if (!opInfo || !opInfo->isMovable) return nullptr;
if (!opInfo || !opInfo->isMovable) {
return nullptr;
}
SgStatement* bestPos = nullptr;
int minDistance = INT_MAX;
@@ -147,12 +291,22 @@ static bool canMoveTo(SgStatement* from, SgStatement* to, SgForStmt* loop) {
SgStatement* loopStart = loop->lexNext();
SgStatement* loopEnd = loop->lastNodeOfStmt();
if (!loopStart || !loopEnd) return false;
if (to->lineNumber() < loopStart->lineNumber() || to->lineNumber() > loopEnd->lineNumber()) {
return false;
}
SgStatement* current = from;
unordered_set<SgStatement*> visited;
while (current && current != loopEnd) {
if (visited.find(current) != visited.end()) {
return false;
}
visited.insert(current);
if (control_tags.find(current->variant()) != control_tags.end()) {
return false;
}
@@ -203,29 +357,98 @@ static bool applyOperatorReordering(SgForStmt* loop, vector<SgStatement*>& newOr
SgStatement* loopStart = loop->lexNext();
SgStatement* loopEnd = loop->lastNodeOfStmt();
if (!loopStart || !loopEnd) return false;
vector<SgStatement*> originalOrder;
SgStatement* current = loopStart;
while (current && current != loopEnd) {
if (isSgExecutableStatement(current) && current->variant() == ASSIGN_STAT) {
originalOrder.push_back(current);
}
current = current->lexNext();
}
bool orderChanged = false;
if (originalOrder.size() == newOrder.size()) {
for (size_t i = 0; i < originalOrder.size(); i++) {
if (originalOrder[i] != newOrder[i]) {
orderChanged = true;
break;
}
}
} else {
orderChanged = true;
}
if (!orderChanged) {
return false;
}
vector<SgStatement*> extractedStatements;
vector<char*> savedComments;
unordered_set<SgStatement*> extractedSet;
map<SgStatement*, int> originalLineNumbers;
for (SgStatement* stmt : newOrder) {
if (stmt && stmt != loop && stmt != loopEnd) {
if (stmt && stmt != loop && stmt != loopEnd && extractedSet.find(stmt) == extractedSet.end()) {
if (control_tags.find(stmt->variant()) != control_tags.end()) {
continue;
}
if (!canSafelyExtract(stmt, loop)) {
continue;
}
bool isMoving = false;
for (size_t i = 0; i < originalOrder.size(); i++) {
if (originalOrder[i] == stmt) {
for (size_t j = 0; j < newOrder.size(); j++) {
if (newOrder[j] == stmt && i != j) {
isMoving = true;
break;
}
}
break;
}
}
if (!isMoving) {
continue;
}
originalLineNumbers[stmt] = stmt->lineNumber();
savedComments.push_back(stmt->comments() ? strdup(stmt->comments()) : nullptr);
SgStatement* extracted = stmt->extractStmt();
if (extracted) {
extractedStatements.push_back(extracted);
extractedSet.insert(stmt);
}
}
}
SgStatement* currentPos = loop;
int lineCounter = loop->lineNumber() + 1;
for (size_t i = 0; i < extractedStatements.size(); i++) {
SgStatement* stmt = extractedStatements[i];
if (stmt) {
SgStatement* nextPos = currentPos->lexNext();
if (nextPos && nextPos != loopEnd) {
if (nextPos->variant() == FOR_NODE && nextPos != loop) {
continue;
}
if (nextPos->variant() == CONTROL_END) {
continue;
}
}
if (i < savedComments.size() && savedComments[i]) {
stmt->setComments(savedComments[i]);
}
stmt->setlineNumber(lineCounter++);
if (originalLineNumbers.find(stmt) != originalLineNumbers.end()) {
stmt->setlineNumber(originalLineNumbers[stmt]);
}
currentPos->insertStmtAfter(*stmt, *loop);
currentPos = stmt;
}
@@ -258,17 +481,24 @@ vector<SAPFOR::BasicBlock*> findFuncBlocksByFuncStatement(SgStatement *st, map<F
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st, vector<SAPFOR::BasicBlock*> blocks) {
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> result;
SgStatement *lastNode = st->lastNodeOfStmt();
while (st && st != lastNode) {
if (loop_tags.find(st -> variant()) != loop_tags.end()) {
SgForStmt *forSt = (SgForStmt*)st;
SgStatement *loopBody = forSt -> body();
SgStatement *lastLoopNode = st->lastNodeOfStmt();
unordered_set<int> blocks_nums;
while (loopBody && loopBody != lastLoopNode) {
SAPFOR::IR_Block* IR = findInstructionsFromOperator(loopBody, blocks).front();
if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) {
result[forSt].push_back(IR -> getBasicBlock());
blocks_nums.insert(IR -> getBasicBlock() -> getNumber());
vector<SAPFOR::IR_Block*> irBlocks = findInstructionsFromOperator(loopBody, blocks);
if (!irBlocks.empty()) {
SAPFOR::IR_Block* IR = irBlocks.front();
if (IR && IR->getBasicBlock()) {
if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) {
result[forSt].push_back(IR -> getBasicBlock());
blocks_nums.insert(IR -> getBasicBlock() -> getNumber());
}
}
}
loopBody = loopBody -> lexNext();
}
@@ -281,19 +511,25 @@ map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st
void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform) {
countOfTransform += 1;
std::cout << "SWAP_OPERATORS Pass Started" << std::endl;
const int funcNum = file -> numberOfFunctions();
for (int i = 0; i < funcNum; ++i) {
SgStatement *st = file -> functions(i);
vector<SAPFOR::BasicBlock*> blocks = findFuncBlocksByFuncStatement(st, FullIR);
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopsMapping = findAndAnalyzeLoops(st, blocks);
for (pair<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopForAnalyze: loopsMapping) {
vector<OperatorInfo> operators = analyzeOperatorsInLoop(loopForAnalyze.first, loopForAnalyze.second, FullIR);
map<string, vector<SgStatement*>> varDefinitions = findVariableDefinitions(loopForAnalyze.first, operators);
vector<SgStatement*> newOrder = optimizeOperatorOrder(loopForAnalyze.first, operators, varDefinitions);
applyOperatorReordering(loopForAnalyze.first, newOrder);
}
}
std::cout << "SWAP_OPERATORS Pass Completed" << std::endl;
}