swap operators in AST

This commit is contained in:
Egor Mayorov
2025-05-27 15:55:02 +03:00
committed by ALEXks
parent 32a4a7fd0a
commit 8c6a55463c

View File

@@ -86,20 +86,25 @@ map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st
map<SgStatement*, set<SgStatement*>> AnalyzeLoopAndFindDeps(SgForStmt* forStatement, vector<SAPFOR::BasicBlock*> loopBlocks, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR) map<SgStatement*, set<SgStatement*>> AnalyzeLoopAndFindDeps(SgForStmt* forStatement, vector<SAPFOR::BasicBlock*> loopBlocks, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR)
{ {
map<SgStatement*, set<SgStatement*>> result; map<SgStatement*, set<SgStatement*>> result;
for (SAPFOR::BasicBlock* bb: loopBlocks) { for (SAPFOR::BasicBlock* bb: loopBlocks)
{
map<SAPFOR::Argument*, set<int>> blockReachingDefinitions = bb -> getRD_In(); map<SAPFOR::Argument*, set<int>> blockReachingDefinitions = bb -> getRD_In();
vector<SAPFOR::IR_Block*> instructions = bb -> getInstructions(); vector<SAPFOR::IR_Block*> instructions = bb -> getInstructions();
for (SAPFOR::IR_Block* irBlock: instructions) { for (SAPFOR::IR_Block* irBlock: instructions)
{
// TODO: Think about what to do with function calls and array references. Because there are also dependencies there that are not reflected in RD, but they must be taken into account // TODO: Think about what to do with function calls and array references. Because there are also dependencies there that are not reflected in RD, but they must be taken into account
SAPFOR::Instruction* instr = irBlock -> getInstruction(); SAPFOR::Instruction* instr = irBlock -> getInstruction();
result[instr -> getOperator()]; result[instr -> getOperator()];
// take Argument 1 and it's RD and push operators to final set // take Argument 1 and it's RD and push operators to final set
if (instr -> getArg1() != NULL) { if (instr -> getArg1() != NULL)
{
SAPFOR::Argument* arg = instr -> getArg1(); SAPFOR::Argument* arg = instr -> getArg1();
set<int> prevInstructionsNumbers = blockReachingDefinitions[arg]; set<int> prevInstructionsNumbers = blockReachingDefinitions[arg];
for (int i: prevInstructionsNumbers) { for (int i: prevInstructionsNumbers)
{
SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first; SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first;
if (foundInstruction != NULL) { if (foundInstruction != NULL)
{
SgStatement* prevOp = foundInstruction -> getOperator(); SgStatement* prevOp = foundInstruction -> getOperator();
if (prevOp != forStatement && instr -> getOperator() != forStatement && instr -> getOperator() -> lineNumber() > prevOp -> lineNumber() if (prevOp != forStatement && instr -> getOperator() != forStatement && instr -> getOperator() -> lineNumber() > prevOp -> lineNumber()
&& prevOp -> lineNumber() > forStatement -> lineNumber()) && prevOp -> lineNumber() > forStatement -> lineNumber())
@@ -108,12 +113,15 @@ map<SgStatement*, set<SgStatement*>> AnalyzeLoopAndFindDeps(SgForStmt* forStatem
} }
} }
// take Argument 2 (if exists) and it's RD and push operators to final set // take Argument 2 (if exists) and it's RD and push operators to final set
if (instr -> getArg2() != NULL) { if (instr -> getArg2() != NULL)
{
SAPFOR::Argument* arg = instr -> getArg2(); SAPFOR::Argument* arg = instr -> getArg2();
set<int> prevInstructionsNumbers = blockReachingDefinitions[arg]; set<int> prevInstructionsNumbers = blockReachingDefinitions[arg];
for (int i: prevInstructionsNumbers) { for (int i: prevInstructionsNumbers)
{
SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first; SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first;
if (foundInstruction != NULL) { if (foundInstruction != NULL)
{
SgStatement* prevOp = foundInstruction -> getOperator(); SgStatement* prevOp = foundInstruction -> getOperator();
if (prevOp != forStatement && instr -> getOperator() != forStatement&& instr -> getOperator() -> lineNumber() > prevOp -> lineNumber() if (prevOp != forStatement && instr -> getOperator() != forStatement&& instr -> getOperator() -> lineNumber() > prevOp -> lineNumber()
&& prevOp -> lineNumber() > forStatement -> lineNumber()) && prevOp -> lineNumber() > forStatement -> lineNumber())
@@ -139,27 +147,37 @@ void buildAdditionalDeps(SgForStmt* forStatement, map<SgStatement*, set<SgStatem
while (st && st != lastNode) while (st && st != lastNode)
{ {
if(importantDeps.size() != 0) if(importantDeps.size() != 0)
if (st != importantDeps.back()) { {
if (st != importantDeps.back())
{
dependencies[st].insert(importantDeps.back()); dependencies[st].insert(importantDeps.back());
} }
if (logIfOp != NULL) { }
if (logIfOp != NULL)
{
dependencies[st].insert(logIfOp); dependencies[st].insert(logIfOp);
logIfOp = NULL; logIfOp = NULL;
} }
if (st -> variant() == LOGIF_NODE) { if (st -> variant() == LOGIF_NODE)
{
logIfOp = st; logIfOp = st;
} }
if (importantDepsTags.find(st -> variant()) != importantDepsTags.end()) { if (importantDepsTags.find(st -> variant()) != importantDepsTags.end())
{
importantDeps.push_back(st); importantDeps.push_back(st);
} }
if (importantUpdDepsTags.find(st -> variant()) != importantUpdDepsTags.end()) { if (importantUpdDepsTags.find(st -> variant()) != importantUpdDepsTags.end())
{
importantDeps.pop_back(); importantDeps.pop_back();
importantDeps.push_back(st); importantDeps.push_back(st);
} }
if (importantEndTags.find(st -> variant()) != importantEndTags.end()) { if (importantEndTags.find(st -> variant()) != importantEndTags.end())
{
if(importantDeps.size() != 0) if(importantDeps.size() != 0)
{
importantDeps.pop_back(); importantDeps.pop_back();
} }
}
st = st -> lexNext(); st = st -> lexNext();
} }
} }
@@ -180,16 +198,18 @@ struct ReadyOpCompare {
} }
}; };
vector<SgStatement*> scheduleOperations( vector<SgStatement*> scheduleOperations(const map<SgStatement*, set<SgStatement*>>& dependencies)
const map<SgStatement*, set<SgStatement*>>& dependencies {
) {
// get all statements // get all statements
unordered_set<SgStatement*> allStmtsSet; unordered_set<SgStatement*> allStmtsSet;
for (const auto& pair : dependencies) { for (const auto& pair : dependencies)
{
allStmtsSet.insert(pair.first); allStmtsSet.insert(pair.first);
for (SgStatement* dep : pair.second) for (SgStatement* dep : pair.second)
{
allStmtsSet.insert(dep); allStmtsSet.insert(dep);
} }
}
vector<SgStatement*> allStmts(allStmtsSet.begin(), allStmtsSet.end()); vector<SgStatement*> allStmts(allStmtsSet.begin(), allStmtsSet.end());
// count deps and build reversed graph // count deps and build reversed graph
unordered_map<SgStatement*, vector<SgStatement*>> graph; unordered_map<SgStatement*, vector<SgStatement*>> graph;
@@ -199,7 +219,8 @@ vector<SgStatement*> scheduleOperations(
inDegree[op] = 0; inDegree[op] = 0;
// find and remember initial dependencies // find and remember initial dependencies
unordered_set<SgStatement*> dependentStmts; unordered_set<SgStatement*> dependentStmts;
for (const auto& pair : dependencies) { for (const auto& pair : dependencies)
{
SgStatement* op = pair.first; SgStatement* op = pair.first;
const auto& deps = pair.second; const auto& deps = pair.second;
degree[op] = deps.size(); degree[op] = deps.size();
@@ -210,40 +231,57 @@ vector<SgStatement*> scheduleOperations(
graph[dep].push_back(op); graph[dep].push_back(op);
} }
for (SgStatement* op : allStmts) for (SgStatement* op : allStmts)
{
if (!degree.count(op)) if (!degree.count(op))
{
degree[op] = 0; degree[op] = 0;
}
}
// build queues // build queues
using PQ = priority_queue<ReadyOp, vector<ReadyOp>, ReadyOpCompare>; using PQ = priority_queue<ReadyOp, vector<ReadyOp>, ReadyOpCompare>;
PQ readyDependent; PQ readyDependent;
queue<SgStatement*> readyIndependent; queue<SgStatement*> readyIndependent;
size_t arrivalCounter = 0; size_t arrivalCounter = 0;
for (auto op : allStmts) { for (auto op : allStmts)
if (inDegree[op] == 0) { {
if (dependentStmts.count(op)) { if (inDegree[op] == 0)
{
if (dependentStmts.count(op))
{
readyDependent.emplace(op, degree[op], arrivalCounter++); readyDependent.emplace(op, degree[op], arrivalCounter++);
} else { }
else
{
readyIndependent.push(op); readyIndependent.push(op);
} }
} }
} }
// main sort algorythm // main sort algorythm
vector<SgStatement*> executionOrder; vector<SgStatement*> executionOrder;
while (!readyDependent.empty() || !readyIndependent.empty()) { while (!readyDependent.empty() || !readyIndependent.empty())
{
SgStatement* current = nullptr; SgStatement* current = nullptr;
if (!readyDependent.empty()) { if (!readyDependent.empty())
{
current = readyDependent.top().stmt; current = readyDependent.top().stmt;
readyDependent.pop(); readyDependent.pop();
} else { }
else
{
current = readyIndependent.front(); current = readyIndependent.front();
readyIndependent.pop(); readyIndependent.pop();
} }
executionOrder.push_back(current); executionOrder.push_back(current);
for (SgStatement* neighbor : graph[current]) { for (SgStatement* neighbor : graph[current])
{
inDegree[neighbor]--; inDegree[neighbor]--;
if (inDegree[neighbor] == 0) { if (inDegree[neighbor] == 0) {
if (dependentStmts.count(neighbor)) { if (dependentStmts.count(neighbor))
{
readyDependent.emplace(neighbor, degree[neighbor], arrivalCounter++); readyDependent.emplace(neighbor, degree[neighbor], arrivalCounter++);
} else { }
else
{
readyIndependent.push(neighbor); readyIndependent.push(neighbor);
} }
} }
@@ -252,6 +290,21 @@ vector<SgStatement*> scheduleOperations(
return executionOrder; return executionOrder;
} }
void buildNewAST(SgStatement* loop, vector<SgStatement*>& newBody)
{
SgStatement* endDo = loop->lastNodeOfStmt();
SgStatement* st = loop;
int lineNum = loop -> lineNumber() + 1;
for (int i = 0; i < newBody.size(); i++)
{
st -> setLexNext(*newBody[i]);
st = st -> lexNext();
st -> setlineNumber(lineNum);
lineNum++;
}
st -> setLexNext(*endDo);
}
void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform) void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform)
@@ -286,6 +339,13 @@ void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*
for (auto v: new_order) for (auto v: new_order)
if (v -> lineNumber() > firstLine) if (v -> lineNumber() > firstLine)
cout << v -> lineNumber() << endl; cout << v -> lineNumber() << endl;
buildNewAST(loopForAnalyze.first, new_order);
st = loopForAnalyze.first -> lexNext();
while (st != loopForAnalyze.first -> lastNodeOfStmt())
{
cout << st -> lineNumber() << " " << st -> sunparse() << endl;
st = st -> lexNext();
}
} }
} }