Adding handling of nested loops and conditional statements

This commit is contained in:
Egor Mayorov
2025-12-11 01:41:58 +03:00
parent e9d5a2ee70
commit a6e6af7577

View File

@@ -17,7 +17,6 @@
using namespace std; using namespace std;
string getNameByArg(SAPFOR::Argument* arg); string getNameByArg(SAPFOR::Argument* arg);
SgSymbol* getSybolByArg(SAPFOR::Argument* arg);
static vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, vector<SAPFOR::BasicBlock*> Blocks) { static vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, vector<SAPFOR::BasicBlock*> Blocks) {
vector<SAPFOR::IR_Block*> result; vector<SAPFOR::IR_Block*> result;
@@ -36,9 +35,9 @@ static vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, v
return result; return result;
} }
unordered_set<int> loop_tags = {FOR_NODE}; unordered_set<int> loop_tags = {FOR_NODE}; // Loop statements
unordered_set<int> control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE}; unordered_set<int> control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE, LOGIF_NODE}; // Control structures that cannot be moved
unordered_set<int> control_end_tags = {CONTROL_END}; unordered_set<int> control_end_tags = {CONTROL_END}; // End marker
struct OperatorInfo { struct OperatorInfo {
SgStatement* stmt; SgStatement* stmt;
@@ -46,21 +45,170 @@ struct OperatorInfo {
set<string> definedVars; set<string> definedVars;
int lineNumber; int lineNumber;
bool isMovable; bool isMovable;
OperatorInfo(SgStatement* s) : stmt(s), lineNumber(s->lineNumber()), isMovable(true) {} OperatorInfo(SgStatement* s) : stmt(s), lineNumber(s->lineNumber()), isMovable(true) {}
}; };
/**
 * Reports whether `stmt` is to be treated as embedded in `parent`.
 *
 * Returns false when either pointer is null or when both point to the same
 * statement. Otherwise the answer is true if `parent->isIncludedInStmt(*stmt)`
 * is non-zero. For a LOGIF_NODE parent two extra checks run first:
 *   - `stmt` carries the same line number as `parent`;
 *   - some node on the lexical chain from `parent` up to (but excluding)
 *     `parent->lastNodeOfStmt()` is `stmt` itself or reports
 *     `isIncludedInStmt(*stmt)`.
 */
static bool isStatementEmbedded(SgStatement* stmt, SgStatement* parent) {
    const bool comparable = stmt && parent && stmt != parent;
    if (!comparable) {
        return false;
    }

    if (parent->variant() == LOGIF_NODE) {
        // Same source line as the logical IF counts as embedded.
        if (stmt->lineNumber() == parent->lineNumber()) {
            return true;
        }

        // Scan the lexical chain that belongs to the logical IF.
        SgStatement* const stop = parent->lastNodeOfStmt();
        for (SgStatement* node = parent; node && node != stop; node = node->lexNext()) {
            if (node == stmt || node->isIncludedInStmt(*stmt)) {
                return true;
            }
        }
    }

    // Generic containment check for every kind of parent.
    return parent->isIncludedInStmt(*stmt) != 0;
}
/**
 * Tells whether `stmt` is a loop boundary marker, i.e. its variant is
 * FOR_NODE or CONTROL_END. A null pointer is never a boundary.
 */
static bool isLoopBoundary(SgStatement* stmt) {
    if (stmt == nullptr) {
        return false;
    }
    const auto tag = stmt->variant();
    return tag == FOR_NODE || tag == CONTROL_END;
}
/**
 * Checks whether `stmt` lies inside a FOR loop nested in `loop`.
 *
 * The decision is made purely on line numbers: `stmt` must fall within the
 * line range spanned by `loop->lexNext()` .. `loop->lastNodeOfStmt()`, and
 * also within the corresponding range of some other FOR_NODE found on the
 * lexical chain between those two statements.
 *
 * @return true if such a nested FOR loop exists; false otherwise, including
 *         when any argument or loop bound is null.
 */
static bool isPartOfNestedLoop(SgStatement* stmt, SgForStmt* loop) {
    if (stmt == nullptr || loop == nullptr) {
        return false;
    }

    SgStatement* const bodyFirst = loop->lexNext();
    SgStatement* const bodyLast = loop->lastNodeOfStmt();
    if (bodyFirst == nullptr || bodyLast == nullptr) {
        return false;
    }

    // True when stmt's line number is within [first, last] (inclusive).
    const auto withinLines = [stmt](SgStatement* first, SgStatement* last) {
        return stmt->lineNumber() >= first->lineNumber() &&
               stmt->lineNumber() <= last->lineNumber();
    };

    if (!withinLines(bodyFirst, bodyLast)) {
        return false;
    }

    for (SgStatement* node = bodyFirst; node && node != bodyLast; node = node->lexNext()) {
        if (node->variant() != FOR_NODE || node == loop) {
            continue;
        }

        SgForStmt* const inner = static_cast<SgForStmt*>(node);
        SgStatement* const innerFirst = inner->lexNext();
        SgStatement* const innerLast = inner->lastNodeOfStmt();
        if (innerFirst && innerLast && withinLines(innerFirst, innerLast)) {
            return true;
        }
    }
    return false;
}
/**
 * Decides whether `stmt` may be extracted from the body of `loop`.
 *
 * Extraction is refused when:
 *   - either argument is null;
 *   - `stmt` is a loop boundary (see isLoopBoundary);
 *   - `stmt`'s variant is listed in `control_tags`;
 *   - `stmt` is located inside a nested loop (see isPartOfNestedLoop);
 *   - the loop has no first body statement or no last node;
 *   - while walking the body from its first statement up to `stmt`, a
 *     LOGIF_NODE on the same line as `stmt` is met, or a statement whose
 *     variant is in `control_tags` embeds `stmt` (see isStatementEmbedded).
 */
static bool canSafelyExtract(SgStatement* stmt, SgForStmt* loop) {
    if (stmt == nullptr || loop == nullptr) {
        return false;
    }

    // Short-circuit order matches the individual checks: boundary, control tag, nesting.
    if (isLoopBoundary(stmt) ||
        control_tags.count(stmt->variant()) != 0 ||
        isPartOfNestedLoop(stmt, loop)) {
        return false;
    }

    SgStatement* const bodyFirst = loop->lexNext();
    SgStatement* const bodyLast = loop->lastNodeOfStmt();
    if (bodyFirst == nullptr || bodyLast == nullptr) {
        return false;
    }

    for (SgStatement* node = bodyFirst; node && node != bodyLast; node = node->lexNext()) {
        const auto tag = node->variant();

        if (tag == LOGIF_NODE && node->lineNumber() == stmt->lineNumber()) {
            return false;
        }
        if (control_tags.count(tag) != 0 && isStatementEmbedded(stmt, node)) {
            return false;
        }

        // Nothing past stmt itself needs to be inspected.
        if (node == stmt) {
            break;
        }
    }
    return true;
}
static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFOR::BasicBlock*> blocks, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR) { static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFOR::BasicBlock*> blocks, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR) {
vector<OperatorInfo> operators; vector<OperatorInfo> operators;
SgStatement* loopStart = loop->lexNext(); SgStatement* loopStart = loop->lexNext();
SgStatement* loopEnd = loop->lastNodeOfStmt(); SgStatement* loopEnd = loop->lastNodeOfStmt();
if (!loopStart || !loopEnd) {
return operators;
}
SgStatement* current = loopStart; SgStatement* current = loopStart;
unordered_set<SgStatement*> visited;
while (current && current != loopEnd) { while (current && current != loopEnd) {
if (visited.find(current) != visited.end()) {
break;
}
visited.insert(current);
if (isLoopBoundary(current)) {
current = current->lexNext();
continue;
}
if (current->variant() == FOR_NODE && current != loop) {
SgStatement* nestedEnd = current->lastNodeOfStmt();
if (nestedEnd) {
current = nestedEnd->lexNext();
} else {
current = current->lexNext();
}
continue;
}
if (isSgExecutableStatement(current)) { if (isSgExecutableStatement(current)) {
if (control_tags.find(current->variant()) != control_tags.end()) {
current = current->lexNext();
continue;
}
if (current->variant() != ASSIGN_STAT) {
current = current->lexNext();
continue;
}
OperatorInfo opInfo(current); OperatorInfo opInfo(current);
vector<SAPFOR::IR_Block*> irBlocks = findInstructionsFromOperator(current, blocks); vector<SAPFOR::IR_Block*> irBlocks = findInstructionsFromOperator(current, blocks);
for (auto irBlock : irBlocks) { for (auto irBlock : irBlocks) {
if (!irBlock || !irBlock->getInstruction()) continue;
SAPFOR::Instruction* instr = irBlock->getInstruction(); SAPFOR::Instruction* instr = irBlock->getInstruction();
if (instr->getArg1()) { if (instr->getArg1()) {
@@ -83,10 +231,6 @@ static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFO
} }
} }
if (control_tags.find(current->variant()) != control_tags.end()) {
opInfo.isMovable = false;
}
operators.push_back(opInfo); operators.push_back(opInfo);
} }
current = current->lexNext(); current = current->lexNext();
@@ -97,13 +241,11 @@ static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFO
static map<string, vector<SgStatement*>> findVariableDefinitions(SgForStmt* loop, vector<OperatorInfo>& operators) { static map<string, vector<SgStatement*>> findVariableDefinitions(SgForStmt* loop, vector<OperatorInfo>& operators) {
map<string, vector<SgStatement*>> varDefinitions; map<string, vector<SgStatement*>> varDefinitions;
for (auto& op : operators) { for (auto& op : operators) {
for (const string& var : op.definedVars) { for (const string& var : op.definedVars) {
varDefinitions[var].push_back(op.stmt); varDefinitions[var].push_back(op.stmt);
} }
} }
return varDefinitions; return varDefinitions;
} }
@@ -121,7 +263,9 @@ static SgStatement* findBestPosition(SgStatement* operatorStmt, vector<OperatorI
} }
} }
if (!opInfo || !opInfo->isMovable) return nullptr; if (!opInfo || !opInfo->isMovable) {
return nullptr;
}
SgStatement* bestPos = nullptr; SgStatement* bestPos = nullptr;
int minDistance = INT_MAX; int minDistance = INT_MAX;
@@ -147,12 +291,22 @@ static bool canMoveTo(SgStatement* from, SgStatement* to, SgForStmt* loop) {
SgStatement* loopStart = loop->lexNext(); SgStatement* loopStart = loop->lexNext();
SgStatement* loopEnd = loop->lastNodeOfStmt(); SgStatement* loopEnd = loop->lastNodeOfStmt();
if (!loopStart || !loopEnd) return false;
if (to->lineNumber() < loopStart->lineNumber() || to->lineNumber() > loopEnd->lineNumber()) { if (to->lineNumber() < loopStart->lineNumber() || to->lineNumber() > loopEnd->lineNumber()) {
return false; return false;
} }
SgStatement* current = from; SgStatement* current = from;
unordered_set<SgStatement*> visited;
while (current && current != loopEnd) { while (current && current != loopEnd) {
if (visited.find(current) != visited.end()) {
return false;
}
visited.insert(current);
if (control_tags.find(current->variant()) != control_tags.end()) { if (control_tags.find(current->variant()) != control_tags.end()) {
return false; return false;
} }
@@ -203,29 +357,98 @@ static bool applyOperatorReordering(SgForStmt* loop, vector<SgStatement*>& newOr
SgStatement* loopStart = loop->lexNext(); SgStatement* loopStart = loop->lexNext();
SgStatement* loopEnd = loop->lastNodeOfStmt(); SgStatement* loopEnd = loop->lastNodeOfStmt();
if (!loopStart || !loopEnd) return false;
vector<SgStatement*> originalOrder;
SgStatement* current = loopStart;
while (current && current != loopEnd) {
if (isSgExecutableStatement(current) && current->variant() == ASSIGN_STAT) {
originalOrder.push_back(current);
}
current = current->lexNext();
}
bool orderChanged = false;
if (originalOrder.size() == newOrder.size()) {
for (size_t i = 0; i < originalOrder.size(); i++) {
if (originalOrder[i] != newOrder[i]) {
orderChanged = true;
break;
}
}
} else {
orderChanged = true;
}
if (!orderChanged) {
return false;
}
vector<SgStatement*> extractedStatements; vector<SgStatement*> extractedStatements;
vector<char*> savedComments; vector<char*> savedComments;
unordered_set<SgStatement*> extractedSet;
map<SgStatement*, int> originalLineNumbers;
for (SgStatement* stmt : newOrder) { for (SgStatement* stmt : newOrder) {
if (stmt && stmt != loop && stmt != loopEnd) { if (stmt && stmt != loop && stmt != loopEnd && extractedSet.find(stmt) == extractedSet.end()) {
if (control_tags.find(stmt->variant()) != control_tags.end()) {
continue;
}
if (!canSafelyExtract(stmt, loop)) {
continue;
}
bool isMoving = false;
for (size_t i = 0; i < originalOrder.size(); i++) {
if (originalOrder[i] == stmt) {
for (size_t j = 0; j < newOrder.size(); j++) {
if (newOrder[j] == stmt && i != j) {
isMoving = true;
break;
}
}
break;
}
}
if (!isMoving) {
continue;
}
originalLineNumbers[stmt] = stmt->lineNumber();
savedComments.push_back(stmt->comments() ? strdup(stmt->comments()) : nullptr); savedComments.push_back(stmt->comments() ? strdup(stmt->comments()) : nullptr);
SgStatement* extracted = stmt->extractStmt(); SgStatement* extracted = stmt->extractStmt();
if (extracted) { if (extracted) {
extractedStatements.push_back(extracted); extractedStatements.push_back(extracted);
extractedSet.insert(stmt);
} }
} }
} }
SgStatement* currentPos = loop; SgStatement* currentPos = loop;
int lineCounter = loop->lineNumber() + 1;
for (size_t i = 0; i < extractedStatements.size(); i++) { for (size_t i = 0; i < extractedStatements.size(); i++) {
SgStatement* stmt = extractedStatements[i]; SgStatement* stmt = extractedStatements[i];
if (stmt) { if (stmt) {
SgStatement* nextPos = currentPos->lexNext();
if (nextPos && nextPos != loopEnd) {
if (nextPos->variant() == FOR_NODE && nextPos != loop) {
continue;
}
if (nextPos->variant() == CONTROL_END) {
continue;
}
}
if (i < savedComments.size() && savedComments[i]) { if (i < savedComments.size() && savedComments[i]) {
stmt->setComments(savedComments[i]); stmt->setComments(savedComments[i]);
} }
stmt->setlineNumber(lineCounter++);
if (originalLineNumbers.find(stmt) != originalLineNumbers.end()) {
stmt->setlineNumber(originalLineNumbers[stmt]);
}
currentPos->insertStmtAfter(*stmt, *loop); currentPos->insertStmtAfter(*stmt, *loop);
currentPos = stmt; currentPos = stmt;
} }
@@ -258,17 +481,24 @@ vector<SAPFOR::BasicBlock*> findFuncBlocksByFuncStatement(SgStatement *st, map<F
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st, vector<SAPFOR::BasicBlock*> blocks) { map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st, vector<SAPFOR::BasicBlock*> blocks) {
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> result; map<SgForStmt*, vector<SAPFOR::BasicBlock*>> result;
SgStatement *lastNode = st->lastNodeOfStmt(); SgStatement *lastNode = st->lastNodeOfStmt();
while (st && st != lastNode) { while (st && st != lastNode) {
if (loop_tags.find(st -> variant()) != loop_tags.end()) { if (loop_tags.find(st -> variant()) != loop_tags.end()) {
SgForStmt *forSt = (SgForStmt*)st; SgForStmt *forSt = (SgForStmt*)st;
SgStatement *loopBody = forSt -> body(); SgStatement *loopBody = forSt -> body();
SgStatement *lastLoopNode = st->lastNodeOfStmt(); SgStatement *lastLoopNode = st->lastNodeOfStmt();
unordered_set<int> blocks_nums; unordered_set<int> blocks_nums;
while (loopBody && loopBody != lastLoopNode) { while (loopBody && loopBody != lastLoopNode) {
SAPFOR::IR_Block* IR = findInstructionsFromOperator(loopBody, blocks).front(); vector<SAPFOR::IR_Block*> irBlocks = findInstructionsFromOperator(loopBody, blocks);
if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) { if (!irBlocks.empty()) {
result[forSt].push_back(IR -> getBasicBlock()); SAPFOR::IR_Block* IR = irBlocks.front();
blocks_nums.insert(IR -> getBasicBlock() -> getNumber()); if (IR && IR->getBasicBlock()) {
if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) {
result[forSt].push_back(IR -> getBasicBlock());
blocks_nums.insert(IR -> getBasicBlock() -> getNumber());
}
}
} }
loopBody = loopBody -> lexNext(); loopBody = loopBody -> lexNext();
} }
@@ -281,19 +511,25 @@ map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st
void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform) { void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform) {
countOfTransform += 1; countOfTransform += 1;
std::cout << "SWAP_OPERATORS Pass Started" << std::endl;
const int funcNum = file -> numberOfFunctions(); const int funcNum = file -> numberOfFunctions();
for (int i = 0; i < funcNum; ++i) { for (int i = 0; i < funcNum; ++i) {
SgStatement *st = file -> functions(i); SgStatement *st = file -> functions(i);
vector<SAPFOR::BasicBlock*> blocks = findFuncBlocksByFuncStatement(st, FullIR); vector<SAPFOR::BasicBlock*> blocks = findFuncBlocksByFuncStatement(st, FullIR);
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopsMapping = findAndAnalyzeLoops(st, blocks); map<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopsMapping = findAndAnalyzeLoops(st, blocks);
for (pair<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopForAnalyze: loopsMapping) { for (pair<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopForAnalyze: loopsMapping) {
vector<OperatorInfo> operators = analyzeOperatorsInLoop(loopForAnalyze.first, loopForAnalyze.second, FullIR); vector<OperatorInfo> operators = analyzeOperatorsInLoop(loopForAnalyze.first, loopForAnalyze.second, FullIR);
map<string, vector<SgStatement*>> varDefinitions = findVariableDefinitions(loopForAnalyze.first, operators); map<string, vector<SgStatement*>> varDefinitions = findVariableDefinitions(loopForAnalyze.first, operators);
vector<SgStatement*> newOrder = optimizeOperatorOrder(loopForAnalyze.first, operators, varDefinitions); vector<SgStatement*> newOrder = optimizeOperatorOrder(loopForAnalyze.first, operators, varDefinitions);
applyOperatorReordering(loopForAnalyze.first, newOrder); applyOperatorReordering(loopForAnalyze.first, newOrder);
} }
} }
std::cout << "SWAP_OPERATORS Pass Completed" << std::endl;
} }