Files
SAPFOR/Sapfor/_src/Utils/module_utils.cpp
2025-03-25 20:39:29 +03:00

665 lines
21 KiB
C++

#include <vector>
#include <set>
#include <map>
#include <string>
#include "dvm.h"
#include "errors.h"
#include "utils.h"
#include "../GraphCall/graph_calls_func.h"
#include "module_utils.h"
using std::vector;
using std::set;
using std::string;
using std::map;
using std::pair;
using std::make_pair;
// Appends every MODULE statement of `file` to `modules`, in source order.
// The walk jumps over the whole body of each module and of each program unit
// reported by file->functions(i), so only top-level statements are inspected.
void findModulesInFile(SgFile* file, vector<SgStatement*>& modules)
{
    SgStatement* cur = file->firstStatement();

    set<SgStatement*> units;
    const int unitCount = file->numberOfFunctions();
    for (int idx = 0; idx < unitCount; ++idx)
        units.insert(file->functions(idx));

    while (cur != NULL)
    {
        if (cur->variant() == MODULE_STMT)
        {
            modules.push_back(cur);
            cur = cur->lastNodeOfStmt();
        }
        else if (!units.empty() && units.count(cur) != 0)
            cur = cur->lastNodeOfStmt(); // skip the body of a program unit

        cur = cur->lexNext();
    }
}
// Fills `modulesAndFunctions` with all modules of `file` (source order),
// followed by all program units reported by file->functions(i).
void getModulesAndFunctions(SgFile* file, vector<SgStatement*>& modulesAndFunctions)
{
    findModulesInFile(file, modulesAndFunctions);
    for (int idx = 0, total = file->numberOfFunctions(); idx < total; ++idx)
        modulesAndFunctions.push_back(file->functions(idx));
}
// Builds, for every module declared in `file`, the set of module names it
// USEs. Direct uses come from the module's specification part; the scan stops
// at the first SUBROUTINE/FUNCTION header. Uses inherited transitively through
// other modules of the same file are added as well.
map<string, set<string>> createMapOfModuleUses(SgFile* file)
{
    map<string, set<string>> uses;

    vector<SgStatement*> modules;
    findModulesInFile(file, modules);
    for (SgStatement* module : modules)
    {
        const string name = module->symbol()->identifier();
        SgStatement* const end = module->lastNodeOfStmt();
        for (SgStatement* st = module->lexNext(); st != end; st = st->lexNext())
        {
            const int var = st->variant();
            if (var == PROC_HEDR || var == FUNC_HEDR)
                break;
            if (var == USE_STMT)
                uses[name].insert(st->symbol()->identifier());
        }
    }

    // Propagate to a fixed point: a module also sees what its used modules see.
    for (bool changed = true; changed; )
    {
        changed = false;
        for (auto& entry : uses)
        {
            set<string> closure(entry.second);
            for (const string& used : entry.second)
            {
                const auto found = uses.find(used);
                if (found == uses.end())
                    continue;
                for (const string& indirect : found->second)
                    if (closure.insert(indirect).second)
                        changed = true;
            }
            entry.second = closure;
        }
    }
    return uses;
}
// For statements whose fileName() equals `file`'s own name (included code is
// skipped), records:
//   - each USEd module name in moduleUses[file name];
//   - each declared module as moduleDecls[module name] = file name.
// A module name that is already present in `moduleDecls` is an internal error.
void fillModuleUse(SgFile* file, map<string, set<string>>& moduleUses, map<string, string>& moduleDecls)
{
    const string currFN = file->filename();
    for (SgStatement* st = file->firstStatement(); st; st = st->lexNext())
    {
        if (st->fileName() != currFN)
            continue;

        const int var = st->variant();
        if (var == USE_STMT)
            moduleUses[currFN].insert(st->symbol()->identifier());
        else if (var == MODULE_STMT)
        {
            const string moduleN = st->symbol()->identifier();
            if (moduleDecls.count(moduleN) != 0)
                printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
            moduleDecls[moduleN] = currFN;
        }
    }
}
// For every file in `moduleUsesByFile`, removes the USEd module names that
// `moduleDecls` says are declared in that same file. Names declared in another
// file, or not present in `moduleDecls` at all, are kept.
void filterModuleUse(map<string, set<string>>& moduleUsesByFile, map<string, string>& moduleDecls)
{
    for (auto& byFile : moduleUsesByFile)
    {
        const string& fileName = byFile.first;
        set<string> kept;
        for (const string& modName : byFile.second)
        {
            const auto decl = moduleDecls.find(modName);
            const bool declaredHere = decl != moduleDecls.end() && decl->second == fileName;
            if (!declaredHere)
                kept.insert(modName);
        }
        byFile.second = kept;
    }
    // Disabled: additionally remove modules that are already reachable
    // transitively through other used modules.
    /*map<string, set<string>> modIncludeMod;
    for (auto& mod : moduleDecls)
    {
        string name = mod.first;
        string file = mod.second;
        auto it = moduleUsesByFile.find(file);
        if (it != moduleUsesByFile.end())
            modIncludeMod[name] = it->second;
    }
    bool change = true;
    while (change)
    {
        change = false;
        for (auto& mod : modIncludeMod)
        {
            set<string> newSet = mod.second;
            for (auto& included : mod.second)
            {
                auto it = modIncludeMod.find(included);
                if (it == modIncludeMod.end())
                    continue;
                for (auto& elem : it->second)
                {
                    if (newSet.find(elem) == newSet.end())
                    {
                        newSet.insert(elem);
                        change = true;
                    }
                }
            }
            mod.second = newSet;
        }
    }
    for (auto& elem : moduleUsesByFile)
    {
        set<string> newSet = elem.second;
        for (auto& setElem : elem.second)
        {
            auto it = modIncludeMod.find(setElem);
            if (it != modIncludeMod.end())
                for (auto& toRem : it->second)
                    newSet.erase(toRem);
        }
        elem.second = newSet;
    }*/
}
// If `currF` is one of `funcContains` (the routines contained in `obj`),
// appends the USE statements of `obj`'s specification part (up to CONTAINS
// or the end of `obj`) to `useStats`.
static void addUseStatements(SgStatement* currF, SgStatement* obj, vector<SgStatement*>& useStats,
                             const vector<SgStatement*>& funcContains)
{
    bool isContained = false;
    for (SgStatement* contained : funcContains)
    {
        if (contained == currF)
        {
            isContained = true;
            break;
        }
    }
    if (!isContained)
        return;

    SgStatement* const last = obj->lastNodeOfStmt();
    for (SgStatement* st = obj->lexNext(); st != last; st = st->lexNext())
    {
        const int var = st->variant();
        if (var == CONTAINS_STMT)
            break;
        if (var == USE_STMT)
            useStats.push_back(st);
    }
}
// Collects into `useStats` the USE statements visible in the program unit
// (PROGRAM/SUBROUTINE/FUNCTION) that encloses `st`. These are its own
// specification-part USEs, plus those of any function or module that hosts
// it through CONTAINS.
void fillUsedModulesInFunction(SgStatement* st, vector<SgStatement*>& useStats)
{
    checkNull(st, convertFileName(__FILE__).c_str(), __LINE__);
    // climb to the enclosing program-unit header
    for (int var = st->variant(); var != PROG_HEDR && var != PROC_HEDR && var != FUNC_HEDR; var = st->variant())
    {
        st = st->controlParent();
        checkNull(st, convertFileName(__FILE__).c_str(), __LINE__);
    }

    for (SgStatement* decl = st->lexNext(); !isSgExecutableStatement(decl); decl = decl->lexNext())
        if (decl->variant() == USE_STMT)
            useStats.push_back(decl);

    // NOTE(review): host functions are taken from current_file while modules
    // are taken from st->getFile(); these presumably coincide - confirm.
    for (int idx = 0; idx < current_file->numberOfFunctions(); ++idx)
    {
        SgStatement* const host = current_file->functions(idx);
        vector<SgStatement*> contained;
        findContainsFunctions(host, contained);
        addUseStatements(st, host, useStats, contained);
    }

    vector<SgStatement*> modules;
    findModulesInFile(st->getFile(), modules);
    for (SgStatement* module : modules)
    {
        vector<SgStatement*> contained;
        findContainsFunctions(module, contained, true);
        addUseStatements(st, module, useStats, contained);
    }
}
// For every module listed in both `modByUse` and `locNames`, scans its
// (first, second) symbol pairs. The name compared against `varName` is that
// of `second` when present, otherwise that of `first`. On a match, the
// identifier of `first` is appended to `altNames`.
static void findByUse(map<string, vector<pair<SgSymbol*, SgSymbol*>>>& modByUse, const string& varName,
                      const set<string>& locNames, vector<string>& altNames)
{
    for (const auto& modEntry : modByUse)
    {
        if (locNames.count(modEntry.first) == 0)
            continue;

        for (const auto& names : modEntry.second)
        {
            SgSymbol* const compared = (names.second != NULL) ? names.second : names.first;
            checkNull(compared, convertFileName(__FILE__).c_str(), __LINE__);
            if (compared->identifier() == varName)
                altNames.push_back(names.first->identifier());
        }
    }
}
// Feeds each statement of `start`'s specification part to fillUseStatement.
// The scan stops at the first executable statement, at CONTAINS, or at a
// nested SUBROUTINE/FUNCTION header.
static void fillInfo(SgStatement* start,
                     set<string>& useMod,
                     map<string, vector<pair<SgSymbol*, SgSymbol*>>>& modByUse,
                     map<string, vector<pair<SgSymbol*, SgSymbol*>>>& modByUseOnly)
{
    for (SgStatement* st = start; st != start->lastNodeOfStmt(); st = st->lexNext())
    {
        const bool endOfSpec = isSgExecutableStatement(st)
                               || st->variant() == CONTAINS_STMT
                               || (st != start && (st->variant() == PROC_HEDR || st->variant() == FUNC_HEDR));
        if (endOfSpec)
            break;
        fillUseStatement(st, useMod, modByUse, modByUseOnly);
    }
}
// Returns the first MODULE statement in `modules` whose identifier equals
// `name`, or NULL if there is none.
static SgStatement* findModWithName(const vector<SgStatement*>& modules, const string& name)
{
    for (SgStatement* candidate : modules)
    {
        if (candidate->variant() != MODULE_STMT)
            continue;
        if (candidate->symbol()->identifier() == name)
            return candidate;
    }
    return NULL;
}
// Cache used by getModuleSymbols(SgStatement*).
static map<SgStatement*, set<SgSymbol*>> symbolsForFunc;
// All project files; filled lazily by Array::GetNameInLocationS.
static set<string> allFiles;

// Walks the symbol list starting after `func`'s own symbol and adds to `symbs`
// every symbol whose scope is `func` and that is flagged IS_BY_USE.
static void getModuleSymbols(SgStatement* func, set<SgSymbol*>& symbs)
{
    for (SgSymbol* s = func->symbol()->next(); s; s = s->next())
        if (s->scope() == func && IS_BY_USE(s))
            symbs.insert(s);
}
// Returns the IS_BY_USE symbols visible in `func`. These are the symbols
// scoped to `func` itself and, when `func`'s control parent is a program
// header (i.e. `func` is a contained procedure), also the symbols of that host.
// The result is cached in symbolsForFunc under `func`. The returned reference
// points into that cache.
// NOTE(review): nothing in this file invalidates cached entries, so later
// changes to the symbol table are not reflected.
const set<SgSymbol*>& getModuleSymbols(SgStatement *func)
{
    auto cached = symbolsForFunc.find(func);
    if (cached != symbolsForFunc.end())
        return cached->second;

    set<SgSymbol*> symbs;
    getModuleSymbols(func, symbs);

    //if function in contains: add the host unit's symbols as well
    SgStatement* host = func->controlParent();
    if (isSgProgHedrStmt(host))
        getModuleSymbols(host, symbs);

    // Key by `func` itself (not by its host): otherwise the lookup above never
    // hits, and the host's own entry would be overwritten with these symbols.
    return symbolsForFunc[func] = symbs;
}
// Looks for a use-associated symbol visible in `func` whose original symbol is
// named `varName` and whose scope's symbol is named `locName`.
// If several local names qualify, the lexicographically smallest one is
// returned, which keeps the choice deterministic. Having no match is reported
// as an internal error.
SgSymbol* getNameInLocation(SgStatement* func, const string& varName, const string& locName)
{
    map<string, SgSymbol*> candidates;
    for (SgSymbol* s : getModuleSymbols(func))
    {
        SgSymbol* orig = OriginalSymbol(s);
        //any suitable symbol can be used
        const bool sameName = orig->identifier() == varName;
        if (sameName && orig->scope()->symbol()->identifier() == locName)
            candidates[s->identifier()] = s;
    }

    if (!candidates.empty())
        return candidates.begin()->second;

    __spf_print(1, "%s %s %s\n", func->symbol()->identifier(), varName.c_str(), locName.c_str());
    printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
    return NULL;
}
// Returns the symbol under which `curr` is visible in the program unit that
// encloses `location`:
//   - if `curr` is IS_BY_USE, the local symbol found by
//     getNameInLocation(func, original name, name of the original's scope);
//   - otherwise `curr` itself.
// If `location` is in another file, the current file is switched to it for
// the lookup and switched back afterwards.
SgSymbol* getNameInLocation(SgSymbol* curr, SgStatement* location)
{
    string oldFileName = "";
    if (location->getFileId() != current_file_id)
    {
        oldFileName = current_file->filename();
        if (!location->switchToFile())
            printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
    }

    SgStatement* func = getFuncStat(location, { MODULE_STMT });
    if (func == NULL)
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

    SgSymbol* returnVal = curr;
    if (IS_BY_USE(curr))
    {
        SgSymbol* orig = OriginalSymbol(curr);
        // Name of the scope that declares the original symbol. It is named
        // distinctly so it no longer shadows the `location` parameter.
        const string declScopeName = orig->scope()->symbol()->identifier();
        returnVal = getNameInLocation(func, orig->identifier(), declScopeName);
    }
    checkNull(returnVal, convertFileName(__FILE__).c_str(), __LINE__);

    if (oldFileName != "" && SgFile::switchToFile(oldFileName) == -1)
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
    return returnVal;
}
namespace Distribution
{
const string Array::GetNameInLocation(void* location_p) const
{
return ((SgSymbol*)GetNameInLocationS(location_p))->identifier();
}
void* Array::GetNameInLocationS(void* location_p) const
{
SgStatement* location = (SgStatement*)location_p;
string oldFileName = "";
if (location->getFileId() != current_file_id)
{
oldFileName = current_file->filename();
if (!location->switchToFile())
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
SgStatement* func = getFuncStat(location, { MODULE_STMT });
if (func == NULL)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
if (allFiles.size() == 0)
allFiles = getAllFilesInProject();
const pair<int, int> lineRange = make_pair(func->lineNumber(), func->lastNodeOfStmt()->lineNumber());
const string& filename = func->fileName();
SgSymbol* returnVal = NULL;
if (locationPos.first == l_MODULE)
{
const string& varName = shortName;
const string& locName = locationPos.second;
returnVal = getNameInLocation(func, varName, locName);
}
else
returnVal = GetDeclSymbol(filename, lineRange, allFiles);
checkNull(returnVal, convertFileName(__FILE__).c_str(), __LINE__);
if (oldFileName != "" && SgFile::switchToFile(oldFileName) == -1)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
return returnVal;
}
}
void fixUseOnlyStmt(SgFile *file, const vector<ParallelRegion*> &regs)
{
for (int z = 0; z < file->numberOfFunctions(); ++z)
{
vector<SgStatement*> modules;
findModulesInFile(file, modules);
map<string, SgStatement*> mod;
for (auto &elem : modules)
mod[elem->symbol()->identifier()] = elem;
if (modules.size())
{
SgStatement *func = file->functions(z);
bool hasTemplateUse = false;
set<DIST::Array*> needToAdd;
for (auto st = func; st != func->lastNodeOfStmt(); st = st->lexNext())
{
if (isSgExecutableStatement(st))
break;
if (st->variant() == USE_STMT)
{
SgExpression *ex = st->expr(0);
string modName = st->symbol()->identifier();
auto it = mod.find(modName);
if (modName == "dvmh_Template_Mod")
{
hasTemplateUse = true;
break;
}
if (ex && ex->variant() == ONLY_NODE && it != mod.end())
{
set<string> allS;
for (auto exI = ex->lhs(); exI; exI = exI->rhs())
{
if (exI->lhs()->variant() == RENAME_NODE)
{
if (exI->lhs()->lhs()->symbol())
allS.insert(exI->lhs()->lhs()->symbol()->identifier());
if (exI->lhs()->rhs() && exI->lhs()->rhs()->symbol())
allS.insert(exI->lhs()->rhs()->symbol()->identifier());
}
}
for (auto &parReg : regs)
{
const DataDirective &dataDir = parReg->GetDataDir();
for (auto &rule : dataDir.distrRules)
{
DIST::Array *curr = rule.first;
auto location = curr->GetLocation();
if (location.first == 2 && location.second == modName)
needToAdd.insert(curr);
}
for (auto& rule : dataDir.alignRules)
{
DIST::Array* curr = rule.alignArray;
auto location = curr->GetLocation();
if (location.first == 2 && location.second == modName)
needToAdd.insert(curr);
}
}
}
}
}
if (!hasTemplateUse && needToAdd.size())
{
SgStatement* useSt = new SgStatement(USE_STMT);
useSt->setSymbol(*findSymbolOrCreate(file, "dvmh_Template_Mod"));
useSt->setlineNumber(getNextNegativeLineNumber());
func->insertStmtAfter(*useSt, *func);
}
}
}
}
// If `st` is a USE statement:
//   - inserts its module name, lower-cased, into `useMod`;
//   - for each RENAME_NODE entry of its list, appends the pair
//     (lhs symbol, rhs symbol) - either may be NULL - under that lower-cased
//     name. The pair goes to `modByUseOnly` for `USE ..., ONLY:` statements
//     and to `modByUse` otherwise.
// Any other statement is ignored.
void fillUseStatement(SgStatement *st, set<string> &useMod,
                      map<string, vector<pair<SgSymbol*, SgSymbol*>>> &modByUse,
                      map<string, vector<pair<SgSymbol*, SgSymbol*>>> &modByUseOnly)
{
    if (st->variant() != USE_STMT)
        return;

    SgExpression *ex = st->expr(0);
    string modName = st->symbol()->identifier();
    convertToLower(modName);
    useMod.insert(modName);

    if (ex == NULL)
        return;

    const bool only = ex->variant() == ONLY_NODE;
    map<string, vector<pair<SgSymbol*, SgSymbol*>>> &target = only ? modByUseOnly : modByUse;

    for (SgExpression *item = only ? ex->lhs() : ex; item; item = item->rhs())
    {
        SgExpression *ren = item->lhs();
        if (ren->variant() != RENAME_NODE)
            continue;

        SgSymbol *left = ren->lhs()->symbol();
        SgSymbol *right = ren->rhs() ? ren->rhs()->symbol() : NULL;
        target[modName].push_back(std::make_pair(left, right));
    }
}
// `stat` must be a USE statement; anything else is an internal error.
// For each RENAME_NODE entry of its list (with or without ONLY) whose
// left- and right-hand sides both carry symbols, adds the left-hand symbol to
// byUse[identifier of the right-hand symbol].
static void fillUseStmt(SgStatement* stat, map<string, set<SgSymbol*>>& byUse)
{
    if (stat->variant() != USE_STMT)
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

    SgExpression* ex = stat->expr(0);
    SgExpression* list = NULL;
    if (ex && ex->variant() == ONLY_NODE)
        list = ex->lhs();
    else if (ex && ex->lhs())
        list = ex;

    for (SgExpression* item = list; item; item = item->rhs())
    {
        SgExpression* ren = item->lhs();
        if (ren->variant() != RENAME_NODE)
            continue;
        if (ren->lhs()->symbol() && ren->rhs() && ren->rhs()->symbol())
            byUse[ren->rhs()->symbol()->identifier()].insert(ren->lhs()->symbol());
    }
}
map<string, set<SgSymbol*>> moduleRefsByUseInFunction(SgStatement* stIn)
{
checkNull(stIn, convertFileName(__FILE__).c_str(), __LINE__);
map<string, set<SgSymbol*>> byUse;
int var = stIn->variant();
while (var != PROG_HEDR && var != PROC_HEDR && var != FUNC_HEDR)
{
stIn = stIn->controlParent();
if (stIn == NULL)
return byUse;
var = stIn->variant();
}
auto mapOfUses = createMapOfModuleUses(stIn->getFile());
set<string> useMods;
for (SgStatement* stat = stIn->lexNext(); !isSgExecutableStatement(stat); stat = stat->lexNext())
{
if (stat->variant() == USE_STMT)
{
fillUseStmt(stat, byUse);
useMods.insert(stat->symbol()->identifier());
}
}
const int cpOfSt = stIn->controlParent()->variant();
//contains of func
if (cpOfSt == PROG_HEDR || cpOfSt == PROC_HEDR || cpOfSt == FUNC_HEDR)
{
for (SgStatement* stat = stIn->controlParent()->lexNext(); !isSgExecutableStatement(stat); stat = stat->lexNext())
{
if (stat->variant() == USE_STMT)
{
fillUseStmt(stat, byUse);
useMods.insert(stat->symbol()->identifier());
}
}
}
bool chages = true;
while (chages)
{
chages = false;
set<string> newUseMods(useMods);
for (auto& elem : useMods)
{
auto it = mapOfUses.find(elem);
if (it != mapOfUses.end())
{
for (auto& elem2 : it->second)
{
if (newUseMods.find(elem2) == newUseMods.end())
{
newUseMods.insert(elem2);
chages = true;
}
}
}
}
useMods = newUseMods;
}
vector<SgStatement*> modules;
findModulesInFile(stIn->getFile(), modules);
for (auto& mod : modules)
{
if (useMods.find(mod->symbol()->identifier()) != useMods.end())
{
for (SgStatement* stat = mod->lexNext(); stat != mod->lastNodeOfStmt(); stat = stat->lexNext())
{
const int var = stat->variant();
if (var == USE_STMT)
{
fillUseStmt(stat, byUse);
useMods.insert(stat->symbol()->identifier());
}
else if (var == PROC_HEDR || var == FUNC_HEDR)
break;
}
}
}
return byUse;
}