improved module analysis

This commit is contained in:
ALEXks
2025-02-18 18:57:05 +03:00
parent 6b0eaab96d
commit 7b12fb1bb0
5 changed files with 80 additions and 61 deletions

View File

@@ -872,7 +872,6 @@ ArrayRefExp* createRemoteLink(const LoopGraph* currLoop, const DIST::Array* forA
const set<string> allFiles = getAllFilesInProject();
SgStatement* realStat = (SgStatement*)currLoop->getRealStat(file->filename());
const map<string, set<SgSymbol*>> byUseInFunc = moduleRefsByUseInFunction(realStat);
SgExpression* ex = new SgExpression(EXPR_LIST);
SgExpression* p = ex;

View File

@@ -226,7 +226,6 @@ static SgStatement* getModuleScope(const string& origFull, vector<SgStatement*>&
static vector<SgExpression*>
compliteTieList(const LoopGraph* currLoop, const vector<LoopGraph*>& loops,
const map<DIST::Array*, set<DIST::Array*>>& arrayLinksByFuncCalls,
const map<string, set<SgSymbol*>>& byUseInFunc,
File* file, SgStatement *location,
const set<DIST::Array*>& onlyFor,
const set<string>& privates)
@@ -506,7 +505,6 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
SgStatement* realStat = (SgStatement*)currLoop->getRealStat(file->filename());
SgStatement* parentFunc = getFuncStat(realStat);
const map<string, set<SgSymbol*>> byUseInFunc = moduleRefsByUseInFunction(realStat);
const int nested = countPerfectLoopNest(loopG);
vector<SgSymbol*> loopSymbs;
@@ -666,7 +664,7 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
continue;
directive += (k != 0) ? "," + privVar.first : privVar.first;
list.push_back(new SgVarRefExp(getFromModule(byUseInFunc, privVar.second)));
list.push_back(new SgVarRefExp(getNameInLocation(privVar.second, parentFunc)));
++k;
}
directive += ")";
@@ -694,9 +692,9 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
}
vector<SgExpression*> tieList;
if (sharedMemoryParallelization)
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, byUseInFunc, file, parentFunc, onlyFor, uniqNamesOfPrivates);
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, file, parentFunc, onlyFor, uniqNamesOfPrivates);
else if (onlyFor.size()) // not MPI regime
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, byUseInFunc, file, parentFunc, onlyFor, uniqNamesOfPrivates);
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, file, parentFunc, onlyFor, uniqNamesOfPrivates);
if (tieList.size())
{
@@ -950,16 +948,21 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
for (auto it = reduction.begin(); it != reduction.end(); ++it)
{
const string& nameGroup = it->first;
for (auto& list : it->second)
for (auto& red : it->second)
{
if (k != 0)
{
directive += ",";
p = createAndSetNext(RIGHT, EXPR_LIST, p);
}
SgSymbol* redS;
string clearName = correctSymbolModuleName(red);
if (clearName != red)
redS = getNameInLocation(parentFunc, clearName, getModuleScope(red, moduleList, parentFunc)->symbol()->identifier());
else
redS = findSymbolOrCreate(file, clearName, NULL, parentFunc);
SgSymbol* base = findSymbolOrCreate(file, correctSymbolModuleName(list), NULL, getModuleScope(list, moduleList, parentFunc));
SgSymbol* redS = getFromModule(byUseInFunc, base, list.find("::") != string::npos);
directive += nameGroup + "(" + redS->identifier() + ")";
SgVarRefExp* tmp2 = new SgVarRefExp(redS);
@@ -1009,11 +1012,19 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
p = createAndSetNext(RIGHT, EXPR_LIST, p);
}
SgSymbol* base1 = findSymbolOrCreate(file, correctSymbolModuleName(get<0>(list)), NULL, getModuleScope(get<0>(list), moduleList, parentFunc));
SgSymbol* base2 = findSymbolOrCreate(file, correctSymbolModuleName(get<1>(list)), NULL, getModuleScope(get<1>(list), moduleList, parentFunc));
SgSymbol *redS1, *redS2;
string clearName1 = correctSymbolModuleName(get<0>(list));
string clearName2 = correctSymbolModuleName(get<1>(list));
SgSymbol* redS1 = getFromModule(byUseInFunc, base1, get<0>(list).find("::") != string::npos);
SgSymbol* redS2 = getFromModule(byUseInFunc, base2, get<1>(list).find("::") != string::npos);
if (clearName1 != get<0>(list))
redS1 = getNameInLocation(parentFunc, clearName1, getModuleScope(get<0>(list), moduleList, parentFunc)->symbol()->identifier());
else
redS1 = findSymbolOrCreate(file, clearName1, NULL, parentFunc);
if (clearName2 != get<1>(list))
redS2 = getNameInLocation(parentFunc, clearName2, getModuleScope(get<1>(list), moduleList, parentFunc)->symbol()->identifier());
else
redS2 = findSymbolOrCreate(file, clearName2, NULL, parentFunc);
directive += nameGroup + "(" + redS1->identifier() + ", " + redS2->identifier() + ", " + std::to_string(get<2>(list)) + ")";

View File

@@ -57,27 +57,6 @@ void getModulesAndFunctions(SgFile* file, vector<SgStatement*>& modulesAndFuncti
modulesAndFunctions.push_back(file->functions(i));
}
SgSymbol* getFromModule(const map<string, set<SgSymbol*>>& byUse, SgSymbol* orig, bool processAsModule)
{
if (!processAsModule)
{
checkNull(orig->scope(), convertFileName(__FILE__).c_str(), __LINE__);
if (orig->scope()->variant() != MODULE_STMT)
return orig;
}
if (byUse.size())
{
for (auto& elem : byUse)
{
for (auto& localS : setToMapWithSortByStr(elem.second))
if (OriginalSymbol(localS.second)->thesymb == orig->thesymb)
return localS.second;
}
}
return orig;
}
map<string, set<string>> createMapOfModuleUses(SgFile* file)
{
map<string, set<string>> retValMap;
@@ -335,6 +314,57 @@ static const set<SgSymbol*>& getModeulSymbols(SgStatement *func)
return symbolsForFunc[func];
}
SgSymbol* getNameInLocation(SgStatement* func, const string& varName, const string& locName)
{
map<string, SgSymbol*> altNames;
for (const auto& s : getModeulSymbols(func))
{
SgSymbol* orig = OriginalSymbol(s);
if (orig->identifier() == varName && orig->scope()->symbol()->identifier() == locName)
{
if (altNames.count(s->identifier()))
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
altNames[s->identifier()] = s;
}
}
if (altNames.size() > 0)
return altNames.begin()->second;
else
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
return NULL;
}
SgSymbol* getNameInLocation(SgSymbol* curr, SgStatement* location)
{
string oldFileName = "";
if (location->getFileId() != current_file_id)
{
oldFileName = current_file->filename();
if (!location->switchToFile())
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
SgStatement* func = getFuncStat(location, { MODULE_STMT });
if (func == NULL)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
SgSymbol* returnVal = curr;
if (IS_BY_USE(curr))
{
const string location = OriginalSymbol(curr)->scope()->symbol()->identifier();
returnVal = getNameInLocation(func, OriginalSymbol(curr)->identifier(), location);
}
checkNull(returnVal, convertFileName(__FILE__).c_str(), __LINE__);
if (oldFileName != "" && SgFile::switchToFile(oldFileName) == -1)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
return returnVal;
}
namespace Distribution
{
@@ -347,11 +377,9 @@ namespace Distribution
{
SgStatement* location = (SgStatement*)location_p;
int old_id = -1;
string oldFileName = "";
if (location->getFileId() != current_file_id)
{
old_id = current_file_id;
oldFileName = current_file->filename();
if (!location->switchToFile())
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
@@ -374,35 +402,15 @@ namespace Distribution
const string& varName = shortName;
const string& locName = locationPos.second;
map<string, SgSymbol*> altNames;
for (const auto& s : getModeulSymbols(func))
{
SgSymbol* orig = OriginalSymbol(s);
if (orig->identifier() == varName && orig->scope()->symbol()->identifier() == locName)
{
if (altNames.count(s->identifier()))
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
altNames[s->identifier()] = s;
}
}
if (altNames.size() > 0)
returnVal = altNames.begin()->second;
else
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
returnVal = getNameInLocation(func, varName, locName);
}
else
returnVal = GetDeclSymbol(filename, lineRange, allFiles);
checkNull(returnVal, convertFileName(__FILE__).c_str(), __LINE__);
if (old_id != -1)
{
if (SgFile::switchToFile(oldFileName) == -1)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
if (oldFileName != "" && SgFile::switchToFile(oldFileName) == -1)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
return returnVal;
}

View File

@@ -2,10 +2,11 @@
void getModulesAndFunctions(SgFile* file, std::vector<SgStatement*>& modulesAndFunctions);
void findModulesInFile(SgFile* file, std::vector<SgStatement*>& modules);
SgSymbol* getFromModule(const std::map<std::string, std::set<SgSymbol*>>& byUse, SgSymbol* orig, bool processAsModule = false);
std::map<std::string, std::set<std::string>> createMapOfModuleUses(SgFile* file);
void fillModuleUse(SgFile* file, std::map<std::string, std::set<std::string>>& moduleUses, std::map<std::string, std::string>& moduleDecls);
void filterModuleUse(std::map<std::string, std::set<std::string>>& moduleUses, std::map<std::string, std::string>& moduleDecls);
SgSymbol* getNameInLocation(SgStatement* func, const std::string& varName, const std::string& locName);
SgSymbol* getNameInLocation(SgSymbol* curr, SgStatement* location);
void fillUsedModulesInFunction(SgStatement* st, std::vector<SgStatement*>& useStats);
void fillUseStatement(SgStatement* st, std::set<std::string>& useMod, std::map<std::string, std::vector<std::pair<SgSymbol*, SgSymbol*>>>& modByUse, std::map<std::string, std::vector<std::pair<SgSymbol*, SgSymbol*>>>& modByUseOnly);
void fixUseOnlyStmt(SgFile* file, const std::vector<ParallelRegion*>& regs);

View File

@@ -1,3 +1,3 @@
#pragma once
#define VERSION_SPF "2391"
#define VERSION_SPF "2392"