swap operators in AST

This commit is contained in:
Egor Mayorov
2025-05-27 15:55:02 +03:00
parent 7fed8f13c3
commit a86a76705e

View File

@@ -43,7 +43,7 @@ vector<SAPFOR::BasicBlock*> findFuncBlocksByFuncStatement(SgStatement *st, map<F
{
vector<SAPFOR::BasicBlock*> result;
Statement* forSt = (Statement*)st;
for (auto& func : FullIR)
for (auto& func: FullIR)
{
if (func.first -> funcPointer -> getCurrProcessFile() == forSt -> getCurrProcessFile()
&& func.first -> funcPointer -> lineNumber() == forSt -> lineNumber())
@@ -86,20 +86,25 @@ map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st
map<SgStatement*, set<SgStatement*>> AnalyzeLoopAndFindDeps(SgForStmt* forStatement, vector<SAPFOR::BasicBlock*> loopBlocks, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR)
{
map<SgStatement*, set<SgStatement*>> result;
for (SAPFOR::BasicBlock* bb: loopBlocks) {
for (SAPFOR::BasicBlock* bb: loopBlocks)
{
map<SAPFOR::Argument*, set<int>> blockReachingDefinitions = bb -> getRD_In();
vector<SAPFOR::IR_Block*> instructions = bb -> getInstructions();
for (SAPFOR::IR_Block* irBlock: instructions) {
for (SAPFOR::IR_Block* irBlock: instructions)
{
// TODO: Think about what to do with function calls and array references. Because there are also dependencies there that are not reflected in RD, but they must be taken into account
SAPFOR::Instruction* instr = irBlock -> getInstruction();
result[instr -> getOperator()];
// take Argument 1 and it's RD and push operators to final set
if (instr -> getArg1() != NULL) {
if (instr -> getArg1() != NULL)
{
SAPFOR::Argument* arg = instr -> getArg1();
set<int> prevInstructionsNumbers = blockReachingDefinitions[arg];
for (int i: prevInstructionsNumbers) {
for (int i: prevInstructionsNumbers)
{
SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first;
if (foundInstruction != NULL) {
if (foundInstruction != NULL)
{
SgStatement* prevOp = foundInstruction -> getOperator();
if (prevOp != forStatement && instr -> getOperator() != forStatement && instr -> getOperator() -> lineNumber() > prevOp -> lineNumber()
&& prevOp -> lineNumber() > forStatement -> lineNumber())
@@ -108,12 +113,15 @@ map<SgStatement*, set<SgStatement*>> AnalyzeLoopAndFindDeps(SgForStmt* forStatem
}
}
// take Argument 2 (if exists) and it's RD and push operators to final set
if (instr -> getArg2() != NULL) {
if (instr -> getArg2() != NULL)
{
SAPFOR::Argument* arg = instr -> getArg2();
set<int> prevInstructionsNumbers = blockReachingDefinitions[arg];
for (int i: prevInstructionsNumbers) {
for (int i: prevInstructionsNumbers)
{
SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first;
if (foundInstruction != NULL) {
if (foundInstruction != NULL)
{
SgStatement* prevOp = foundInstruction -> getOperator();
if (prevOp != forStatement && instr -> getOperator() != forStatement&& instr -> getOperator() -> lineNumber() > prevOp -> lineNumber()
&& prevOp -> lineNumber() > forStatement -> lineNumber())
@@ -139,27 +147,37 @@ void buildAdditionalDeps(SgForStmt* forStatement, map<SgStatement*, set<SgStatem
while (st && st != lastNode)
{
if(importantDeps.size() != 0)
if (st != importantDeps.back()) {
{
if (st != importantDeps.back())
{
dependencies[st].insert(importantDeps.back());
}
if (logIfOp != NULL) {
}
if (logIfOp != NULL)
{
dependencies[st].insert(logIfOp);
logIfOp = NULL;
}
if (st -> variant() == LOGIF_NODE) {
if (st -> variant() == LOGIF_NODE)
{
logIfOp = st;
}
if (importantDepsTags.find(st -> variant()) != importantDepsTags.end()) {
if (importantDepsTags.find(st -> variant()) != importantDepsTags.end())
{
importantDeps.push_back(st);
}
if (importantUpdDepsTags.find(st -> variant()) != importantUpdDepsTags.end()) {
if (importantUpdDepsTags.find(st -> variant()) != importantUpdDepsTags.end())
{
importantDeps.pop_back();
importantDeps.push_back(st);
}
if (importantEndTags.find(st -> variant()) != importantEndTags.end()) {
if (importantEndTags.find(st -> variant()) != importantEndTags.end())
{
if(importantDeps.size() != 0)
{
importantDeps.pop_back();
}
}
st = st -> lexNext();
}
}
@@ -180,16 +198,18 @@ struct ReadyOpCompare {
}
};
vector<SgStatement*> scheduleOperations(
const map<SgStatement*, set<SgStatement*>>& dependencies
) {
vector<SgStatement*> scheduleOperations(const map<SgStatement*, set<SgStatement*>>& dependencies)
{
// get all statements
unordered_set<SgStatement*> allStmtsSet;
for (const auto& pair : dependencies) {
for (const auto& pair : dependencies)
{
allStmtsSet.insert(pair.first);
for (SgStatement* dep : pair.second)
{
allStmtsSet.insert(dep);
}
}
vector<SgStatement*> allStmts(allStmtsSet.begin(), allStmtsSet.end());
// count deps and build reversed graph
unordered_map<SgStatement*, vector<SgStatement*>> graph;
@@ -199,7 +219,8 @@ vector<SgStatement*> scheduleOperations(
inDegree[op] = 0;
// find and remember initial dependencies
unordered_set<SgStatement*> dependentStmts;
for (const auto& pair : dependencies) {
for (const auto& pair : dependencies)
{
SgStatement* op = pair.first;
const auto& deps = pair.second;
degree[op] = deps.size();
@@ -210,40 +231,57 @@ vector<SgStatement*> scheduleOperations(
graph[dep].push_back(op);
}
for (SgStatement* op : allStmts)
{
if (!degree.count(op))
{
degree[op] = 0;
}
}
// build queues
using PQ = priority_queue<ReadyOp, vector<ReadyOp>, ReadyOpCompare>;
PQ readyDependent;
queue<SgStatement*> readyIndependent;
size_t arrivalCounter = 0;
for (auto op : allStmts) {
if (inDegree[op] == 0) {
if (dependentStmts.count(op)) {
for (auto op : allStmts)
{
if (inDegree[op] == 0)
{
if (dependentStmts.count(op))
{
readyDependent.emplace(op, degree[op], arrivalCounter++);
} else {
}
else
{
readyIndependent.push(op);
}
}
}
// main sort algorythm
vector<SgStatement*> executionOrder;
while (!readyDependent.empty() || !readyIndependent.empty()) {
while (!readyDependent.empty() || !readyIndependent.empty())
{
SgStatement* current = nullptr;
if (!readyDependent.empty()) {
if (!readyDependent.empty())
{
current = readyDependent.top().stmt;
readyDependent.pop();
} else {
}
else
{
current = readyIndependent.front();
readyIndependent.pop();
}
executionOrder.push_back(current);
for (SgStatement* neighbor : graph[current]) {
for (SgStatement* neighbor : graph[current])
{
inDegree[neighbor]--;
if (inDegree[neighbor] == 0) {
if (dependentStmts.count(neighbor)) {
if (dependentStmts.count(neighbor))
{
readyDependent.emplace(neighbor, degree[neighbor], arrivalCounter++);
} else {
}
else
{
readyIndependent.push(neighbor);
}
}
@@ -252,6 +290,21 @@ vector<SgStatement*> scheduleOperations(
return executionOrder;
}
void buildNewAST(SgStatement* loop, vector<SgStatement*>& newBody)
{
SgStatement* endDo = loop->lastNodeOfStmt();
SgStatement* st = loop;
int lineNum = loop -> lineNumber() + 1;
for (int i = 0; i < newBody.size(); i++)
{
st -> setLexNext(*newBody[i]);
st = st -> lexNext();
st -> setlineNumber(lineNum);
lineNum++;
}
st -> setLexNext(*endDo);
}
void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform)
@@ -286,6 +339,13 @@ void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*
for (auto v: new_order)
if (v -> lineNumber() > firstLine)
cout << v -> lineNumber() << endl;
buildNewAST(loopForAnalyze.first, new_order);
st = loopForAnalyze.first -> lexNext();
while (st != loopForAnalyze.first -> lastNodeOfStmt())
{
cout << st -> lineNumber() << " " << st -> sunparse() << endl;
st = st -> lexNext();
}
}
}