improved module analysis

This commit is contained in:
ALEXks
2025-02-18 18:57:05 +03:00
committed by Dudarenko
parent 09401376c7
commit 68d2f3253c
5 changed files with 80 additions and 61 deletions

View File

@@ -872,7 +872,6 @@ ArrayRefExp* createRemoteLink(const LoopGraph* currLoop, const DIST::Array* forA
const set<string> allFiles = getAllFilesInProject(); const set<string> allFiles = getAllFilesInProject();
SgStatement* realStat = (SgStatement*)currLoop->getRealStat(file->filename()); SgStatement* realStat = (SgStatement*)currLoop->getRealStat(file->filename());
const map<string, set<SgSymbol*>> byUseInFunc = moduleRefsByUseInFunction(realStat);
SgExpression* ex = new SgExpression(EXPR_LIST); SgExpression* ex = new SgExpression(EXPR_LIST);
SgExpression* p = ex; SgExpression* p = ex;

View File

@@ -226,7 +226,6 @@ static SgStatement* getModuleScope(const string& origFull, vector<SgStatement*>&
static vector<SgExpression*> static vector<SgExpression*>
compliteTieList(const LoopGraph* currLoop, const vector<LoopGraph*>& loops, compliteTieList(const LoopGraph* currLoop, const vector<LoopGraph*>& loops,
const map<DIST::Array*, set<DIST::Array*>>& arrayLinksByFuncCalls, const map<DIST::Array*, set<DIST::Array*>>& arrayLinksByFuncCalls,
const map<string, set<SgSymbol*>>& byUseInFunc,
File* file, SgStatement *location, File* file, SgStatement *location,
const set<DIST::Array*>& onlyFor, const set<DIST::Array*>& onlyFor,
const set<string>& privates) const set<string>& privates)
@@ -506,7 +505,6 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
SgStatement* realStat = (SgStatement*)currLoop->getRealStat(file->filename()); SgStatement* realStat = (SgStatement*)currLoop->getRealStat(file->filename());
SgStatement* parentFunc = getFuncStat(realStat); SgStatement* parentFunc = getFuncStat(realStat);
const map<string, set<SgSymbol*>> byUseInFunc = moduleRefsByUseInFunction(realStat);
const int nested = countPerfectLoopNest(loopG); const int nested = countPerfectLoopNest(loopG);
vector<SgSymbol*> loopSymbs; vector<SgSymbol*> loopSymbs;
@@ -666,7 +664,7 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
continue; continue;
directive += (k != 0) ? "," + privVar.first : privVar.first; directive += (k != 0) ? "," + privVar.first : privVar.first;
list.push_back(new SgVarRefExp(getFromModule(byUseInFunc, privVar.second))); list.push_back(new SgVarRefExp(getNameInLocation(privVar.second, parentFunc)));
++k; ++k;
} }
directive += ")"; directive += ")";
@@ -694,9 +692,9 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
} }
vector<SgExpression*> tieList; vector<SgExpression*> tieList;
if (sharedMemoryParallelization) if (sharedMemoryParallelization)
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, byUseInFunc, file, parentFunc, onlyFor, uniqNamesOfPrivates); tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, file, parentFunc, onlyFor, uniqNamesOfPrivates);
else if (onlyFor.size()) // not MPI regime else if (onlyFor.size()) // not MPI regime
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, byUseInFunc, file, parentFunc, onlyFor, uniqNamesOfPrivates); tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, file, parentFunc, onlyFor, uniqNamesOfPrivates);
if (tieList.size()) if (tieList.size())
{ {
@@ -950,7 +948,7 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
for (auto it = reduction.begin(); it != reduction.end(); ++it) for (auto it = reduction.begin(); it != reduction.end(); ++it)
{ {
const string& nameGroup = it->first; const string& nameGroup = it->first;
for (auto& list : it->second) for (auto& red : it->second)
{ {
if (k != 0) if (k != 0)
{ {
@@ -958,8 +956,13 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
p = createAndSetNext(RIGHT, EXPR_LIST, p); p = createAndSetNext(RIGHT, EXPR_LIST, p);
} }
SgSymbol* base = findSymbolOrCreate(file, correctSymbolModuleName(list), NULL, getModuleScope(list, moduleList, parentFunc)); SgSymbol* redS;
SgSymbol* redS = getFromModule(byUseInFunc, base, list.find("::") != string::npos); string clearName = correctSymbolModuleName(red);
if (clearName != red)
redS = getNameInLocation(parentFunc, clearName, getModuleScope(red, moduleList, parentFunc)->symbol()->identifier());
else
redS = findSymbolOrCreate(file, clearName, NULL, parentFunc);
directive += nameGroup + "(" + redS->identifier() + ")"; directive += nameGroup + "(" + redS->identifier() + ")";
SgVarRefExp* tmp2 = new SgVarRefExp(redS); SgVarRefExp* tmp2 = new SgVarRefExp(redS);
@@ -1009,11 +1012,19 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
p = createAndSetNext(RIGHT, EXPR_LIST, p); p = createAndSetNext(RIGHT, EXPR_LIST, p);
} }
SgSymbol* base1 = findSymbolOrCreate(file, correctSymbolModuleName(get<0>(list)), NULL, getModuleScope(get<0>(list), moduleList, parentFunc)); SgSymbol *redS1, *redS2;
SgSymbol* base2 = findSymbolOrCreate(file, correctSymbolModuleName(get<1>(list)), NULL, getModuleScope(get<1>(list), moduleList, parentFunc)); string clearName1 = correctSymbolModuleName(get<0>(list));
string clearName2 = correctSymbolModuleName(get<1>(list));
SgSymbol* redS1 = getFromModule(byUseInFunc, base1, get<0>(list).find("::") != string::npos); if (clearName1 != get<0>(list))
SgSymbol* redS2 = getFromModule(byUseInFunc, base2, get<1>(list).find("::") != string::npos); redS1 = getNameInLocation(parentFunc, clearName1, getModuleScope(get<0>(list), moduleList, parentFunc)->symbol()->identifier());
else
redS1 = findSymbolOrCreate(file, clearName1, NULL, parentFunc);
if (clearName2 != get<1>(list))
redS2 = getNameInLocation(parentFunc, clearName2, getModuleScope(get<1>(list), moduleList, parentFunc)->symbol()->identifier());
else
redS2 = findSymbolOrCreate(file, clearName2, NULL, parentFunc);
directive += nameGroup + "(" + redS1->identifier() + ", " + redS2->identifier() + ", " + std::to_string(get<2>(list)) + ")"; directive += nameGroup + "(" + redS1->identifier() + ", " + redS2->identifier() + ", " + std::to_string(get<2>(list)) + ")";

View File

@@ -57,27 +57,6 @@ void getModulesAndFunctions(SgFile* file, vector<SgStatement*>& modulesAndFuncti
modulesAndFunctions.push_back(file->functions(i)); modulesAndFunctions.push_back(file->functions(i));
} }
SgSymbol* getFromModule(const map<string, set<SgSymbol*>>& byUse, SgSymbol* orig, bool processAsModule)
{
if (!processAsModule)
{
checkNull(orig->scope(), convertFileName(__FILE__).c_str(), __LINE__);
if (orig->scope()->variant() != MODULE_STMT)
return orig;
}
if (byUse.size())
{
for (auto& elem : byUse)
{
for (auto& localS : setToMapWithSortByStr(elem.second))
if (OriginalSymbol(localS.second)->thesymb == orig->thesymb)
return localS.second;
}
}
return orig;
}
map<string, set<string>> createMapOfModuleUses(SgFile* file) map<string, set<string>> createMapOfModuleUses(SgFile* file)
{ {
map<string, set<string>> retValMap; map<string, set<string>> retValMap;
@@ -335,6 +314,57 @@ static const set<SgSymbol*>& getModeulSymbols(SgStatement *func)
return symbolsForFunc[func]; return symbolsForFunc[func];
} }
SgSymbol* getNameInLocation(SgStatement* func, const string& varName, const string& locName)
{
map<string, SgSymbol*> altNames;
for (const auto& s : getModeulSymbols(func))
{
SgSymbol* orig = OriginalSymbol(s);
if (orig->identifier() == varName && orig->scope()->symbol()->identifier() == locName)
{
if (altNames.count(s->identifier()))
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
altNames[s->identifier()] = s;
}
}
if (altNames.size() > 0)
return altNames.begin()->second;
else
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
return NULL;
}
SgSymbol* getNameInLocation(SgSymbol* curr, SgStatement* location)
{
string oldFileName = "";
if (location->getFileId() != current_file_id)
{
oldFileName = current_file->filename();
if (!location->switchToFile())
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
SgStatement* func = getFuncStat(location, { MODULE_STMT });
if (func == NULL)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
SgSymbol* returnVal = curr;
if (IS_BY_USE(curr))
{
const string location = OriginalSymbol(curr)->scope()->symbol()->identifier();
returnVal = getNameInLocation(func, OriginalSymbol(curr)->identifier(), location);
}
checkNull(returnVal, convertFileName(__FILE__).c_str(), __LINE__);
if (oldFileName != "" && SgFile::switchToFile(oldFileName) == -1)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
return returnVal;
}
namespace Distribution namespace Distribution
{ {
@@ -347,11 +377,9 @@ namespace Distribution
{ {
SgStatement* location = (SgStatement*)location_p; SgStatement* location = (SgStatement*)location_p;
int old_id = -1;
string oldFileName = ""; string oldFileName = "";
if (location->getFileId() != current_file_id) if (location->getFileId() != current_file_id)
{ {
old_id = current_file_id;
oldFileName = current_file->filename(); oldFileName = current_file->filename();
if (!location->switchToFile()) if (!location->switchToFile())
printInternalError(convertFileName(__FILE__).c_str(), __LINE__); printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
@@ -374,35 +402,15 @@ namespace Distribution
const string& varName = shortName; const string& varName = shortName;
const string& locName = locationPos.second; const string& locName = locationPos.second;
returnVal = getNameInLocation(func, varName, locName);
map<string, SgSymbol*> altNames;
for (const auto& s : getModeulSymbols(func))
{
SgSymbol* orig = OriginalSymbol(s);
if (orig->identifier() == varName && orig->scope()->symbol()->identifier() == locName)
{
if (altNames.count(s->identifier()))
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
altNames[s->identifier()] = s;
}
}
if (altNames.size() > 0)
returnVal = altNames.begin()->second;
else
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
} }
else else
returnVal = GetDeclSymbol(filename, lineRange, allFiles); returnVal = GetDeclSymbol(filename, lineRange, allFiles);
checkNull(returnVal, convertFileName(__FILE__).c_str(), __LINE__); checkNull(returnVal, convertFileName(__FILE__).c_str(), __LINE__);
if (old_id != -1) if (oldFileName != "" && SgFile::switchToFile(oldFileName) == -1)
{
if (SgFile::switchToFile(oldFileName) == -1)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__); printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
return returnVal; return returnVal;
} }

View File

@@ -2,10 +2,11 @@
void getModulesAndFunctions(SgFile* file, std::vector<SgStatement*>& modulesAndFunctions); void getModulesAndFunctions(SgFile* file, std::vector<SgStatement*>& modulesAndFunctions);
void findModulesInFile(SgFile* file, std::vector<SgStatement*>& modules); void findModulesInFile(SgFile* file, std::vector<SgStatement*>& modules);
SgSymbol* getFromModule(const std::map<std::string, std::set<SgSymbol*>>& byUse, SgSymbol* orig, bool processAsModule = false);
std::map<std::string, std::set<std::string>> createMapOfModuleUses(SgFile* file); std::map<std::string, std::set<std::string>> createMapOfModuleUses(SgFile* file);
void fillModuleUse(SgFile* file, std::map<std::string, std::set<std::string>>& moduleUses, std::map<std::string, std::string>& moduleDecls); void fillModuleUse(SgFile* file, std::map<std::string, std::set<std::string>>& moduleUses, std::map<std::string, std::string>& moduleDecls);
void filterModuleUse(std::map<std::string, std::set<std::string>>& moduleUses, std::map<std::string, std::string>& moduleDecls); void filterModuleUse(std::map<std::string, std::set<std::string>>& moduleUses, std::map<std::string, std::string>& moduleDecls);
SgSymbol* getNameInLocation(SgStatement* func, const std::string& varName, const std::string& locName);
SgSymbol* getNameInLocation(SgSymbol* curr, SgStatement* location);
void fillUsedModulesInFunction(SgStatement* st, std::vector<SgStatement*>& useStats); void fillUsedModulesInFunction(SgStatement* st, std::vector<SgStatement*>& useStats);
void fillUseStatement(SgStatement* st, std::set<std::string>& useMod, std::map<std::string, std::vector<std::pair<SgSymbol*, SgSymbol*>>>& modByUse, std::map<std::string, std::vector<std::pair<SgSymbol*, SgSymbol*>>>& modByUseOnly); void fillUseStatement(SgStatement* st, std::set<std::string>& useMod, std::map<std::string, std::vector<std::pair<SgSymbol*, SgSymbol*>>>& modByUse, std::map<std::string, std::vector<std::pair<SgSymbol*, SgSymbol*>>>& modByUseOnly);
void fixUseOnlyStmt(SgFile* file, const std::vector<ParallelRegion*>& regs); void fixUseOnlyStmt(SgFile* file, const std::vector<ParallelRegion*>& regs);

View File

@@ -1,3 +1,3 @@
#pragma once #pragma once
#define VERSION_SPF "2391" #define VERSION_SPF "2392"