fixed and improved inlining

This commit is contained in:
ALEXks
2023-11-22 20:21:18 +03:00
parent afdbfdac61
commit 0ba8915fa0
12 changed files with 210 additions and 105 deletions

View File

@@ -39,8 +39,8 @@ using std::stack;
#define DEBUG 0
//TODO: improve parameter checking
static void correctNameIfContains(SgStatement *call, SgExpression *exCall, string &name,
const vector<SgStatement*> &containsFunctions, const string &prefix)
void correctNameIfContains(SgStatement *call, SgExpression *exCall, string &name,
const vector<SgStatement*> &containsFunctions, const string &prefix)
{
if (containsFunctions.size() <= 0)
return;

View File

@@ -219,6 +219,27 @@ struct FuncInfo
}
bool usesIO() const { return (linesOfIO.size() != 0 || linesOfStop.size() != 0); }
std::string getCallName(const std::pair<void*, int>& call_info, const std::string& name, int line)
{
std::set<std::string> names;
for (auto& call : callsFromDetailed)
{
if (call.pointerDetailCallsFrom == call_info && call.detailCallsFrom.second == line)
return call.detailCallsFrom.first;
if (call.detailCallsFrom.second == line)
if (call.detailCallsFrom.first.find(name) != std::string::npos)
names.insert(call.detailCallsFrom.first);
}
//TODO: detect func call better
if (names.size() == 1)
return *names.begin();
return "";
}
};
struct CallV

View File

@@ -46,6 +46,7 @@ void detectCopies(std::map<std::string, std::vector<FuncInfo*>> &allFuncInfo);
void fillInterfaceBlock(std::map<std::string, std::vector<FuncInfo*>>& allFuncInfo);
parF detectExpressionType(SgExpression* exp);
void findContainsFunctions(SgStatement *st, std::vector<SgStatement*> &found, const bool searchAll = false);
void correctNameIfContains(SgStatement* call, SgExpression* exCall, std::string& name, const std::vector<SgStatement*>& containsFunctions, const std::string& prefix);
int countPerfectLoopNest(SgStatement* st);
void setInlineAttributeToCalls(const std::map<std::string, FuncInfo*>& allFunctions, const std::map<std::string, std::set<std::pair<std::string, int>>>& inDataChains, const std::map<std::string, std::vector<SgStatement*>>& hiddenData);
#endif

View File

@@ -1141,6 +1141,8 @@ static inline void replaceCall(SgExpression* exp, SgExpression* par, const int i
SgAssignStmt* assign = new SgAssignStmt(*new SgVarRefExp(*v), *exp->copyPtr());
assign->setlineNumber(getNextNegativeLineNumber());
assign->setLocalLineNumber(callSt->lineNumber());
insertPlace->insertStmtBefore(*assign, *callSt->controlParent());
// replace function call to a new variable
@@ -1150,16 +1152,18 @@ static inline void replaceCall(SgExpression* exp, SgExpression* par, const int i
lhs ? par->setLhs(new SgVarRefExp(*v)) : par->setRhs(new SgVarRefExp(*v));
}
static void recFindFuncCall(SgExpression* exp, SgExpression* par, const int i, const bool lhs,
static void recFindFuncCall(FuncInfo* currentFuncI,
SgExpression* exp, SgExpression* par, const int i, const bool lhs,
const string& funcName, bool& foundCall,
SgStatement* callSt, set<SgSymbol*>& newSymbols, SgStatement* insertPlace)
{
if (exp)
{
recFindFuncCall(exp->rhs(), exp, i, false, funcName, foundCall, callSt, newSymbols, insertPlace);
recFindFuncCall(exp->lhs(), exp, i, true, funcName, foundCall, callSt, newSymbols, insertPlace);
recFindFuncCall(currentFuncI, exp->rhs(), exp, i, false, funcName, foundCall, callSt, newSymbols, insertPlace);
recFindFuncCall(currentFuncI, exp->lhs(), exp, i, true, funcName, foundCall, callSt, newSymbols, insertPlace);
if (exp->variant() == FUNC_CALL && exp->symbol() && exp->symbol()->identifier() == funcName)
if (exp->variant() == FUNC_CALL && exp->symbol() &&
currentFuncI->getCallName(make_pair(exp, exp->variant()), exp->symbol()->identifier(), callSt->lineNumber()) == funcName)
{
foundCall = true;
if (par) // do not extract external func call
@@ -1211,7 +1215,7 @@ static SgType* getTrueType(SgType* inExp, parF funcParType)
return inExp;
}
static inline void PrecalculateActualParameters(SgStatement* st, SgSymbol* s, SgExpression* e,
static inline void PrecalculateActualParameters(SgStatement* st, SgExpression* e,
const FuncInfo* func, set<SgSymbol*>& newSymbols)
{
// Precalculate actual parameter expressions
@@ -1353,7 +1357,7 @@ static bool run_inliner(const map<string, FuncInfo*>& funcMap, set<SgStatement*>
for (auto& callSt : toInsert)
{
SgStatement* insertPlace = callSt;
SgStatement* currentFunc = getFuncStat(callSt);
SgProgHedrStmt* currentFunc = (SgProgHedrStmt*) getFuncStat(callSt);
if (usedByFunc.find(currentFunc) == usedByFunc.end())
{
@@ -1370,11 +1374,12 @@ static bool run_inliner(const map<string, FuncInfo*>& funcMap, set<SgStatement*>
usedByFunc[currentFunc] = used;
}
auto itF = funcMap.find(currentFunc->symbol()->identifier());
auto itF = funcMap.find(currentFunc->nameWithContains());
if (itF == funcMap.end())
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
renameArgsIfGlobalNameIntersection(itF->second, globalNames);
FuncInfo* currentFuncI = itF->second;
renameArgsIfGlobalNameIntersection(currentFuncI, globalNames);
set<string> useStatsInFunc;
for (SgStatement* s = currentFunc; s != currentFunc->lastNodeOfStmt(); s = s->lexNext())
@@ -1439,7 +1444,7 @@ static bool run_inliner(const map<string, FuncInfo*>& funcMap, set<SgStatement*>
// 1.a: make statement preprocessing
// if call statement contains several inlining functions, split every such call
for (int i = 0; i < 3; ++i)
recFindFuncCall(callSt->expr(i), NULL, i, false, funcName, foundCall, callSt, newSymbols, insertPlace);
recFindFuncCall(currentFuncI, callSt->expr(i), NULL, i, false, funcName, foundCall, callSt, newSymbols, insertPlace);
__spf_print(DEB, "---argument preprocessing---\n"); // DEBUG
// 1.b: make argument preprocessing
@@ -1451,18 +1456,23 @@ static bool run_inliner(const map<string, FuncInfo*>& funcMap, set<SgStatement*>
{
if (st->variant() == ASSIGN_STAT)
{
auto s = st->expr(1)->symbol();
if (s && s->identifier() == funcName)
auto rPart = st->expr(1);
int line = st->lineNumber() < 0 ? st->localLineNumber() : st->lineNumber();
if (rPart->variant() == FUNC_CALL && rPart->symbol() &&
currentFuncI->getCallName(make_pair(rPart, rPart->variant()), rPart->symbol()->identifier(), line) == funcName)
{
if (isSgVarRefExp(st->expr(0)) || isSgArrayRefExp(st->expr(0)) && !isSgArrayType(st->expr(0)->type()))
PrecalculateActualParameters(st, s, st->expr(1)->lhs(), func, newSymbols);
PrecalculateActualParameters(st, st->expr(1)->lhs(), func, newSymbols);
}
}
else if (st->variant() == PROC_STAT)
{
if (st->symbol() && st->symbol()->identifier() == funcName)
if (st->symbol() &&
currentFuncI->getCallName(make_pair(st, st->variant()), st->symbol()->identifier(), st->lineNumber()) == funcName)
{
foundCall = true;
if (st->expr(0))
PrecalculateActualParameters(st, st->symbol(), st->expr(0), func, newSymbols);
PrecalculateActualParameters(st, st->expr(0), func, newSymbols);
}
}
}
@@ -1483,33 +1493,34 @@ static bool run_inliner(const map<string, FuncInfo*>& funcMap, set<SgStatement*>
switch (st->variant())
{
case ASSIGN_STAT:
{
auto rPart = st->expr(1);
if (rPart->variant() == FUNC_CALL && rPart->symbol() && rPart->symbol()->identifier() == funcName)
{
bool doInline = insert(st, funcStat, rPart->lhs(), newSymbols, funcMap, toDelete, useStats, SPF_messages, point);
change |= doInline;
isInlined |= doInline;
auto rPart = st->expr(1);
int line = st->lineNumber() < 0 ? st->localLineNumber() : st->lineNumber();
if (rPart->variant() == FUNC_CALL && rPart->symbol() &&
currentFuncI->getCallName(make_pair(rPart, rPart->variant()), rPart->symbol()->identifier(), line) == funcName)
{
bool doInline = insert(st, funcStat, rPart->lhs(), newSymbols, funcMap, toDelete, useStats, SPF_messages, point);
change |= doInline;
isInlined |= doInline;
}
}
}
continue;
break;
case PROC_STAT:
if (st->symbol() && st->symbol()->identifier() == funcName)
if (st->symbol() &&
currentFuncI->getCallName(make_pair(st, st->variant()), st->symbol()->identifier(), st->lineNumber()) == funcName)
{
bool doInline = insert(st, funcStat, st->expr(0), newSymbols, funcMap, toDelete, useStats, SPF_messages, point);
change |= doInline;
isInlined |= doInline;
}
continue;
break;
default:
continue;
break;
}
}
for (auto& st : toDelete)
{
st->extractStmt();
}
}
}
@@ -1686,7 +1697,7 @@ bool inliner(const string& fileName_in, const string& funcName, const int lineNu
}
if (markers.size() == 0)
return 0;
return false;
PointCall point;
point.mainPoint.first = fileName;
@@ -1696,7 +1707,7 @@ bool inliner(const string& fileName_in, const string& funcName, const int lineNu
point.currLvl = 0;
point.currCall = func->funcName;
__spf_print(1, "INLINE %s\n", func->funcName.c_str());
__spf_print(1, " INLINE %s\n", func->funcName.c_str());
#ifdef _WIN32
sendMessage_2lvl(wstring(L"<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> '") + wstring(func->funcName.begin(), func->funcName.end()) + L"'");
#else
@@ -1705,6 +1716,12 @@ bool inliner(const string& fileName_in, const string& funcName, const int lineNu
//1 level
bool isInlined = run_inliner(funcMap, toInsert, SPF_messages, fileName, func, newSymbsToDeclare, point, commonBlocks);
if (isInlined == false)
{
__spf_print(1, " missing ...\n");
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
if (deepLvl >= 0 && isInlined)
{
int currDeep = 0;
@@ -1729,7 +1746,7 @@ bool inliner(const string& fileName_in, const string& funcName, const int lineNu
point.currLvl = currDeep;
point.currCall = next.first->funcName;
__spf_print(1, "INLINE %s\n", next.first->funcName.c_str());
__spf_print(1, " INLINE %s\n", next.first->funcName.c_str());
bool isInlined = run_inliner(funcMap, next.second, SPF_messages, fileName, next.first, newSymbsToDeclare, point, commonBlocks);
changed |= isInlined;
}

View File

@@ -1150,6 +1150,8 @@ static bool runAnalysis(SgProject &project, const int curr_regime, const bool ne
for (SgStatement* st = file->firstStatement(); st; st = st->lexNext())
removeOmpDir(st);
}
else if (curr_regime == GET_MIN_MAX_BLOCK_DIST)
getMaxMinBlockDistribution(file, min_max_block);
else if (curr_regime == TEST_PASS)
{
//test pass
@@ -2034,6 +2036,49 @@ static bool runAnalysis(SgProject &project, const int curr_regime, const bool ne
if (inDataProc.size())
{
// if inlineI -> fixed lines
if (inDataChains.size() == 0 && inDataChainsStart.size() == 0)
{
for (int z = 0; z < inDataProc.size(); ++z)
{
if (std::get<2>(inDataProc[z]) > 0)
continue;
auto funcToInl = std::get<0>(inDataProc[z]);
auto file = std::get<1>(inDataProc[z]);
int absoluteLine = 0;
int shilftLine = -std::get<2>(inDataProc[z]);
for (auto& funcByFile : allFuncInfo)
{
if (funcByFile.first != file)
continue;
for (auto& func : funcByFile.second)
{
int targetLine = func->linesNum.first + shilftLine;
__spf_print(1, "%s target %d + %d = %d\n", func->funcName.c_str(), func->linesNum.first, shilftLine, targetLine);
for (auto& detCall : func->callsFromDetailed)
{
__spf_print(1, "%s %d\n", detCall.detailCallsFrom.first.c_str(), detCall.detailCallsFrom.second);
if (detCall.detailCallsFrom == make_pair(funcToInl, targetLine))
{
absoluteLine = targetLine;
break;
}
}
if (absoluteLine)
break;
}
}
if (absoluteLine == 0)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
std::get<2>(inDataProc[z]) = absoluteLine;
}
}
map<int, vector<int>> sortByLvl;
int maxLvlCall = 0;
@@ -2074,8 +2119,13 @@ static bool runAnalysis(SgProject &project, const int curr_regime, const bool ne
if (std::get<2>(tup) != -1)
{
__spf_print(1, "call inliner with [%s %s %d]\n", std::get<1>(tup).c_str(), std::get<0>(tup).c_str(), std::get<2>(tup));
inliner(std::get<1>(tup), std::get<0>(tup), std::get<2>(tup), allFuncInfo, SPF_messages, newSymbsToDeclare, commonBlocks);
__spf_print(1, " call inliner with [%s %s %d]\n", std::get<1>(tup).c_str(), std::get<0>(tup).c_str(), std::get<2>(tup));
bool isInlined = inliner(std::get<1>(tup), std::get<0>(tup), std::get<2>(tup), allFuncInfo, SPF_messages, newSymbsToDeclare, commonBlocks);
if (!isInlined)
{
__spf_print(1, " missing ...\n");
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
}
}
}
@@ -2192,6 +2242,10 @@ static bool runAnalysis(SgProject &project, const int curr_regime, const bool ne
SelectArrayConfForParallelization(&project, allFuncInfo, loopGraph, createdDirectives, SPF_messages, arrayLinksByFuncCalls, parallelRegions);
removeRegionsWithoutDirs(createdDirectives, parallelRegions, allFuncInfo, SPF_messages);
}
else if (curr_regime == GET_MIN_MAX_BLOCK_DIST)
{
__spf_print(1, "GET_MIN_MAX_BLOCK_DIST: %d %d\n", min_max_block.first, min_max_block.second);
}
const float elapsed = duration_cast<milliseconds>(high_resolution_clock::now() - timeForPass).count() / 1000.;
const float elapsedGlobal = duration_cast<milliseconds>(high_resolution_clock::now() - globalTime).count() / 1000.;
@@ -2864,7 +2918,7 @@ int main(int argc, char **argv)
listOfProject.push_back(FileInfo(file, toAddOpt + "-o " + file + ".dep", "", "", "", fileText, 0));
}
int rethrow = parseFiles(errors, listOfProject, filesCompilationOrder, 0, true);
int rethrow = parseFiles(errors, listOfProject, filesCompilationOrder, 1, true);
if (rethrow == 0)
{
for (auto& err : errors)

View File

@@ -168,6 +168,7 @@ enum passes {
FIX_COMMON_BLOCKS,
REMOVE_OMP_DIRS,
GET_MIN_MAX_BLOCK_DIST,
TEST_PASS,
EMPTY_PASS
@@ -341,6 +342,7 @@ static void setPassValues()
passNames[FIX_COMMON_BLOCKS] = "FIX_COMMON_BLOCKS";
passNames[REMOVE_OMP_DIRS] = "REMOVE_OMP_DIRS";
passNames[GET_MIN_MAX_BLOCK_DIST] = "GET_MIN_MAX_BLOCK_DIST";
passNames[TEST_PASS] = "TEST_PASS";
}

View File

@@ -168,6 +168,9 @@ std::map<int, UserFiles> filesInfo; // information about open,close,write and re
std::map< std::pair<std::string, int>, std::set<std::string>> parametersOfProject; // [file, line] -> set[vars]
//
//for GET_MIN_MAX_BLOCK_DIST
std::pair<int, int> min_max_block = std::make_pair(-1, -1);
//
const char* passNames[EMPTY_PASS + 1];
const char* optionNames[EMPTY_OPTION + 1];
bool passNamesWasInit = false;

View File

@@ -300,6 +300,8 @@ void InitPassesDependencies(map<passes, vector<passes>> &passDepsIn, set<passes>
list({ BUILD_IR, LOOP_GRAPH, LIVE_ANALYSIS_IR }) <= Pass(PRIVATE_ANALYSIS_IR);
Pass(FILE_LINE_INFO) <= Pass(GET_MIN_MAX_BLOCK_DIST);
Pass(CALL_GRAPH2) <= Pass(FIX_COMMON_BLOCKS);
passesIgnoreStateDone.insert({ CREATE_PARALLEL_DIRS, INSERT_PARALLEL_DIRS, INSERT_SHADOW_DIRS, EXTRACT_PARALLEL_DIRS,

View File

@@ -12,6 +12,7 @@
#include <vector>
#include <map>
#include <queue>
#include <set>
#include <utility>
#include <string>
@@ -45,6 +46,7 @@
#include "../LoopAnalyzer/loop_analyzer.h"
using std::map;
using std::queue;
using std::multimap;
using std::pair;
using std::tuple;
@@ -3144,6 +3146,17 @@ static inline void restoreOriginalText(const FileInfo& file)
writeFileFromStr(file.fileName, file.text);
}
static void checkRetCode(FileInfo& info, const string& errorMessage)
{
if (info.error != 0)
info.lvl++;
if (errorMessage.find("Warning 308") != string::npos)
if (info.error == 0)
info.error = 1;
}
extern "C" int parse_file(int argc, char* argv[], char* proj_name);
static vector<string> parseList(vector<FileInfo>& listOfProject,
bool needToInclude, bool needToIncludeForInline,
@@ -3221,8 +3234,8 @@ static vector<string> parseList(vector<FileInfo>& listOfProject,
StdCapture::BeginCapture();
if (needToInclude)
filesModified = applyModuleDeclsForFile(&elem, mapFiles, moduleDelc, mapModuleDeps, modDirectOrder, optSplited, needToIncludeForInline);
else if (needToIncludeForInline) //TODO for modules
filesModified = applyModuleDeclsForFile(&elem, mapFiles, moduleDelc, mapModuleDeps, modDirectOrder, optSplited, true);
else if (needToIncludeForInline) // TODO for modules
filesModified = applyModuleDeclsForFile(&elem, mapFiles, moduleDelc, mapModuleDeps, modDirectOrder, optSplited, needToIncludeForInline);
int retCode = parse_file(optSplited.size(), toParse, "dvm.proj");
if (needToInclude || needToIncludeForInline)
@@ -3235,13 +3248,7 @@ static vector<string> parseList(vector<FileInfo>& listOfProject,
elem.error = retCode;
StdCapture::EndCapture();
errorMessage = StdCapture::GetCapture();
if (elem.error != 0)
elem.lvl++;
if (errorMessage.find("Warning 308") != string::npos)
if (elem.error == 0)
elem.error = 1;
checkRetCode(elem, errorMessage);
}
catch (int err)
{
@@ -4386,3 +4393,53 @@ void removeSpecialCommentsFromProject(SgFile* file)
stF = stF->lexNext();
}
}
void getMaxMinBlockDistribution(SgFile* file, pair<int, int>& min_max)
{
SgStatement* st = file->firstStatement();
while (st)
{
if (isDVM_stat(st))
{
if (st->variant() == DVM_DISTRIBUTE_DIR || st->variant() == DVM_VAR_DECL)
{
for (int z = 0; z < 3; ++z)
{
SgExpression* ex = st->expr(z);
queue<SgExpression*> q;
if (ex)
{
q.push(ex);
int blockCount = 0;
while (q.size())
{
ex = q.front();
q.pop();
if (ex->rhs())
q.push(ex->rhs());
if (ex->lhs())
q.push(ex->lhs());
if (ex->variant() == BLOCK_OP)
blockCount++;
}
if (blockCount)
{
if (min_max == make_pair(-1, -1))
min_max = make_pair(blockCount, blockCount);
else
{
min_max.first = std::min(min_max.first, blockCount);
min_max.second = std::max(min_max.second, blockCount);
}
}
}
}
}
}
st = st->lexNext();
}
}

View File

@@ -107,4 +107,6 @@ bool isEqSymbols(SgSymbol* sym1, SgSymbol* sym2);
std::set<std::string> getAllFilesInProject();
void LogIftoIfThen(SgStatement* stmt);
void removeSpecialCommentsFromProject(SgFile* file);
void removeSpecialCommentsFromProject(SgFile* file);
void getMaxMinBlockDistribution(SgFile* file, std::pair<int, int>& min_max);

View File

@@ -1,3 +1,3 @@
#pragma once
#define VERSION_SPF "2240"
#define VERSION_SPF "2245"

View File

@@ -1262,6 +1262,7 @@ int SPF_GetArrayLinks(void*& context, int winHandler, short *options, short *pro
return retSize;
}
extern std::pair<int, int> min_max_block;
int SPF_GetMaxMinBlockDistribution(void*& context, int winHandler, short *options, short *projName, short *&result, short *&output, int *&outputSize,
short *&outputMessage, int *&outputMessageSize)
{
@@ -1273,64 +1274,10 @@ int SPF_GetMaxMinBlockDistribution(void*& context, int winHandler, short *option
int retSize = -1;
try
{
runPassesForVisualizer(projName, { FILE_LINE_INFO });
runPassesForVisualizer(projName, { GET_MIN_MAX_BLOCK_DIST });
string resVal = "";
int minBlock = 10;
int maxBlock = 0;
for (int z = 0; z < CurrentProject->numberOfFiles(); ++z)
{
SgFile* file = &(CurrentProject->file(z));
SgStatement* st = file->firstStatement();
while (st)
{
if (isDVM_stat(st))
{
if (st->variant() == DVM_DISTRIBUTE_DIR || st->variant() == DVM_VAR_DECL)
{
for (int z = 0; z < 3; ++z)
{
SgExpression* ex = st->expr(z);
queue<SgExpression*> q;
if (ex)
{
q.push(ex);
int blockCount = 0;
while (q.size())
{
ex = q.front();
q.pop();
if (ex->rhs())
q.push(ex->rhs());
if (ex->lhs())
q.push(ex->lhs());
if (ex->variant() == BLOCK_OP)
blockCount++;
}
if (blockCount)
{
minBlock = std::min(minBlock, blockCount);
maxBlock = std::max(maxBlock, blockCount);
}
}
}
}
}
st = st->lexNext();
}
}
if (minBlock == 10 && maxBlock == 0)
minBlock = maxBlock = 0;
resVal = to_string(minBlock) + " " + to_string(maxBlock);
string resVal = "";
resVal = to_string(min_max_block.first) + " " + to_string(min_max_block.second);
copyStringToShort(result, resVal);
retSize = (int)resVal.size() + 1;
@@ -1979,7 +1926,6 @@ static int inline runModificationPass(passes passName, short* projName, short* f
}
extern tuple<string, int, int, int> inData;
extern map<string, string> outData;
int SPF_ChangeSpfIntervals(void*& context, int winHandler, short *options, short *projName, short *folderName, short *&output,
int *&outputSize, short *&outputMessage, int *&outputMessageSize,
short *fileNameToMod, int *toModifyLines,
@@ -2133,7 +2079,7 @@ int SPF_InlineProcedures(void*& context, int winHandler, short* options, short*
if (result.size() < 2)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
for (int z = 0; z < result.size();)
for (int z = 0; z < result.size(); )
{
string procName = result[z++];
int count = -1;