// File: SAPFOR/src/Distribution/DvmhDirective.cpp
// (C++ source, 1316 lines, 46 KiB; header recovered from a repository web view,
//  first blame timestamp 2023-09-14 19:43:13 +03:00)
#include "../Utils/leak_detector.h"
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <vector>
#include <string>
#include <algorithm>
#include "../Utils/types.h"
#include "DvmhDirective.h"
#include "../Utils/errors.h"
#include "../Utils/SgUtils.h"
#include "../Sapfor.h"
#include "../GraphCall/graph_calls_func.h"
#include "dvm.h"
using std::vector;
using std::tuple;
using std::get;
using std::string;
using std::pair;
using std::set;
using std::map;
using std::set_union;
using std::make_pair;
using std::min;
using std::max;
extern map<tuple<int, string, string>, pair<DIST::Array*, DIST::ArrayAccessInfo*>> declaredArrays;
// blame: 2023-09-14 19:43:13 +03:00
// Recursively searches `ex` for a reference to `currArray` (matched by the short name of
// the original symbol) in which more than one subscript has a non-zero shift.
//
// A subscript is counted when it carries exactly one INT_VAL attribute (a pair of access
// coefficients) and that pair is found in shiftsByAccess[<subscript position>] with a
// non-zero value.
//
// Returns true as soon as one such reference is found anywhere in the expression tree.
static bool findArrayRefAndCheck(SgExpression *ex, const DIST::Array* currArray, const vector<map<pair<int, int>, int>> &shiftsByAccess)
{
    if (ex == NULL)
        return false;

    if (ex->variant() == ARRAY_REF && ex->symbol() &&
        OriginalSymbol(ex->symbol())->identifier() == currArray->GetShortName())
    {
        SgArrayRefExp *ref = (SgArrayRefExp*)ex;
        int countOfShadows = 0;
        // only positions that have shift information can be checked; reading past the end
        // of shiftsByAccess would be out of bounds
        for (int i = 0; i < ref->numberOfSubscripts() && i < (int)shiftsByAccess.size(); ++i)
        {
            const vector<int*> &coefs = getAttributes<SgExpression*, int*>(ref->subscript(i), set<int>{ INT_VAL });
            if (coefs.size() != 1)
                continue;

            const pair<int, int> coef(coefs[0][0], coefs[0][1]);
            auto it = shiftsByAccess[i].find(coef);
            if (it != shiftsByAccess[i].end() && it->second != 0)
                countOfShadows++;
        }
        if (countOfShadows > 1)
            return true;
    }

    // the search has no side effects, so the right subtree is only visited when the left one fails
    return findArrayRefAndCheck(ex->lhs(), currArray, shiftsByAccess) ||
           findArrayRefAndCheck(ex->rhs(), currArray, shiftsByAccess);
}
// Decides whether the SHADOW_RENEW of `currArray` needs the CORNER specification.
// Scans the original statements from the loop header up to (not including) its last node.
// Assignments are checked on their right-hand side only; any other statement is checked on
// all three expression slots. The first hit reported by findArrayRefAndCheck() wins.
static bool needCorner(const DIST::Array* currArray, const vector<map<pair<int, int>, int>> &shiftsByAccess, Statement *loop)
{
    SgStatement *header = loop->GetOriginal();
    for (SgStatement *cur = header; cur != header->lastNodeOfStmt(); cur = cur->lexNext())
    {
        if (cur->variant() == ASSIGN_STAT)
        {
            if (findArrayRefAndCheck(cur->expr(1), currArray, shiftsByAccess))
                return true;
            continue;
        }

        for (int slot = 0; slot < 3; ++slot)
            if (findArrayRefAndCheck(cur->expr(slot), currArray, shiftsByAccess))
                return true;
    }
    return false;
}
// Builds one "low:high" (DDOT) subscript per dimension for an ACROSS / SHADOW_RENEW clause.
// Each bound is the base width from `shadowRenew` plus the matching extra shift from
// `shadowRenewShifts`. A dimension without a shift entry (shifts shorter than widths) is
// treated as having a zero shift instead of reading past the end of the vector.
vector<SgExpression*> genSubscripts(const vector<pair<int, int>> &shadowRenew, const vector<pair<int, int>> &shadowRenewShifts)
{
    const pair<int, int> noShift(0, 0);

    vector<SgExpression*> subs;
    subs.reserve(shadowRenew.size());
    for (size_t z = 0; z < shadowRenew.size(); ++z)
    {
        const pair<int, int> &shift = (z < shadowRenewShifts.size()) ? shadowRenewShifts[z] : noShift;
        SgValueExp *low = new SgValueExp(shadowRenew[z].first + shift.first);
        SgValueExp *high = new SgValueExp(shadowRenew[z].second + shift.second);
        subs.push_back(new SgExpression(DDOT, low, high, NULL));
    }
    return subs;
}
// Creates a new expression node of kind `variant` and hangs it under `p`:
// as the left operand when `side == LEFT`, as the right operand when `side == RIGHT`.
// Returns the freshly attached child (as read back from `p`), or NULL for any other side.
SgExpression* createAndSetNext(const int side, const int variant, SgExpression *p)
{
    const bool toLeft = (side == LEFT);
    if (!toLeft && side != RIGHT)
        return NULL;

    SgExpression *child = new SgExpression(variant);
    if (toLeft)
    {
        p->setLhs(child);
        return p->lhs();
    }
    p->setRhs(child);
    return p->rhs();
}
// Turns an integer coefficient into a Sage expression.
// When `digitConv.first` is " - " (as produced by convertDigitToPositive, presumably for
// negative values) the result is MINUS_OP applied to the value -digit; otherwise it is the
// plain integer constant `digit`.
static SgExpression* genComplexExpr(const pair<string, string> &digitConv, const int digit)
{
    const bool asUnaryMinus = (digitConv.first == " - ");
    if (!asUnaryMinus)
        return new SgValueExp(digit);
    return new SgUnaryExp(MINUS_OP, *new SgValueExp(-digit));
}
// Builds the Sage expression for the linear form `a*letter + b`, where (a, b) = expr:
//   (0, 0)        -> "*"
//   (1, 0)        -> letter
//   (a, 0)        -> a*letter
//   (1, b)        -> letter + b
//   (a, b)        -> a*letter + b
// The symbol for `letter` is always looked up / created first, even for the "*" case.
static SgExpression* genSgExpr(SgFile *file, const string &letter, const pair<int, int> expr)
{
    SgSymbol *letterSymb = findSymbolOrCreate(file, letter);
    const int mult = expr.first;
    const int shift = expr.second;

    if (mult == 0 && shift == 0)
        return new SgVarRefExp(findSymbolOrCreate(file, "*"));

    // integer coefficient rendered through convertDigitToPositive / genComplexExpr
    auto coefficient = [](const int value) -> SgExpression* {
        const pair<string, string> conv = convertDigitToPositive(value);
        return genComplexExpr(conv, value);
    };

    if (shift == 0)
    {
        if (mult == 1)
            return new SgVarRefExp(letterSymb);
        SgExpression *factor = coefficient(mult);
        return new SgExpression(MULT_OP, factor, new SgVarRefExp(letterSymb), NULL);
    }

    SgExpression *offset = coefficient(shift);
    SgExpression *scaled = NULL;
    if (mult == 1)
        scaled = new SgVarRefExp(letterSymb);
    else
    {
        SgExpression *factor = coefficient(mult);
        scaled = new SgExpression(MULT_OP, factor, new SgVarRefExp(letterSymb), NULL);
    }
    return new SgExpression(ADD_OP, scaled, offset, NULL);
}
// Walks the expression tree `ex` and adds to `used` the symbol of every array reference
// (as recognised by isArrayRef) and every scalar variable reference (VAR_REF).
static void fillUsedSymbols(SgExpression* ex, set<SgSymbol*>& used)
{
    if (ex == NULL)
        return;

    const bool isReference = isArrayRef(ex) || ex->variant() == VAR_REF;
    if (isReference)
        used.insert(ex->symbol());

    fillUsedSymbols(ex->lhs(), used);
    fillUsedSymbols(ex->rhs(), used);
}
// Collects the identifiers of all variables and arrays referenced by the statements
// strictly inside `loop` (the header itself and its last node are not scanned).
static set<string> fillUsedSymbols(SgStatement *loop)
{
    set<SgSymbol*> symbols;
    SgStatement* const stop = loop->lastNodeOfStmt();

    SgStatement* st = loop->lexNext();
    while (st != stop)
    {
        for (int slot = 0; slot < 3; ++slot)
        {
            SgExpression *part = st->expr(slot);
            if (part)
                fillUsedSymbols(part, symbols);
        }
        st = st->lexNext();
    }

    set<string> names;
    for (SgSymbol* symb : symbols)
        names.insert(symb->identifier());
    return names;
}
// Picks the scope statement a (possibly module-qualified) name belongs to.
// "module::name" resolves to the statement in `moduleList` whose symbol is `module`;
// an unqualified name resolves to `local`. An unknown module is reported as an internal
// error and `local` is returned as a fallback.
static SgStatement* getModuleScope(const string& origFull, vector<SgStatement*>& moduleList, SgStatement *local)
{
    const size_t sep = origFull.find("::");
    if (sep != string::npos)
    {
        const string modName = origFull.substr(0, sep);
        auto found = std::find_if(moduleList.begin(), moduleList.end(),
                                  [&modName](SgStatement* mod) { return modName == mod->symbol()->identifier(); });
        if (found != moduleList.end())
            return *found;
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
    }
    return local;
}
// Builds the list of array references for the TIE clause of a PARALLEL directive.
//
// For every array used in `currLoop` (usedArraysAll in shared-memory mode, usedArrays
// otherwise), restricted to `onlyFor` when that set is non-empty and skipping arrays whose
// short name is in `privates`, an array reference is created with one subscript per
// dimension. Every subscript starts as "*"; a dimension that can be linked to one of the
// loops in `loops` receives that loop's variable instead. Only references with at least one
// such dimension are returned.
//
// Parameters:
//   currLoop              - loop whose used arrays are considered (NOTE: the parameter is
//                           reused below as the iteration variable over `loops`)
//   loops                 - loops to tie with (genDirective passes the loops whose PARALLEL
//                           entry is not "*")
//   arrayLinksByFuncCalls - used by getRealArrayRefs to resolve an array to its real array(s)
//   file, location        - where symbols are looked up / created
//   onlyFor               - if non-empty, only these arrays are considered
//   privates              - short names of private variables; they are never tied
// Returns the TIE array references (possibly empty).
static vector<SgExpression*>
compliteTieList(const LoopGraph* currLoop, const vector<LoopGraph*>& loops,
const map<DIST::Array*, set<DIST::Array*>>& arrayLinksByFuncCalls,
2025-02-18 13:45:20 +03:00
File* file, SgStatement *location,
const set<DIST::Array*>& onlyFor,
const set<string>& privates)
2023-09-14 19:43:13 +03:00
{
vector<SgExpression*> tieList;
// pairs of (first real array, array as used in the loop)
vector<pair<DIST::Array*, DIST::Array*>> realRefsUsed;
const auto& usedArrays = sharedMemoryParallelization ? currLoop->usedArraysAll : currLoop->usedArrays;
for (auto& elem : usedArrays)
2023-09-14 19:43:13 +03:00
{
if (onlyFor.size())
if (onlyFor.find(elem) == onlyFor.end())
continue;
set<DIST::Array*> realRefs;
getRealArrayRefs(elem, elem, realRefs, arrayLinksByFuncCalls);
if (realRefs.size() == 0)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
// only the first real array is used for the template / link lookups below
realRefsUsed.push_back(make_pair(*realRefs.begin(), elem));
}
if (realRefsUsed.size() == 0)
return tieList;
// "*" placeholder; copied into every subscript that is not tied to a loop
SgVarRefExp* zeroS = new SgVarRefExp(findSymbolOrCreate(file, "*"));
for (auto& pairs : realRefsUsed)
{
if (privates.find(pairs.second->GetShortName()) != privates.end())
continue;
2025-02-18 13:45:20 +03:00
SgSymbol* arrayS = (SgSymbol*)pairs.second->GetNameInLocationS(location);
2023-09-14 19:43:13 +03:00
SgArrayRefExp* array = new SgArrayRefExp(*arrayS);
bool needToAdd = false;
vector<SgExpression*> subs;
for (int k = 0; k < pairs.second->GetDimSize(); ++k)
subs.push_back(&zeroS->copy());
for (int z = 0; z < loops.size(); ++z)
{
// NOTE(review): overwrites the parameter; usedArrays was already bound above, so only
// the code below sees the new value
currLoop = loops[z];
if(!sharedMemoryParallelization)
2023-09-14 19:43:13 +03:00
{
// NOTE(review): inside this branch the conditional always yields region->GetId()
const uint64_t regId = sharedMemoryParallelization ? (uint64_t)currLoop : currLoop->region->GetId();
auto dirForLoop = currLoop->directiveForLoop;
auto tmplP = pairs.first->GetTemplateArray(regId, sharedMemoryParallelization != 0);
auto links = pairs.first->GetLinksWithTemplate(regId);
// no mapping for this loop, skip this
// NOTE(review): the comment above does not describe the branch below. Case 1: the
// array's template is the loop's ON target, so array dims map through `links`.
// Only the first tied dimension is taken for each loop (see `break`).
if (tmplP == dirForLoop->arrayRef)
2023-09-14 19:43:13 +03:00
{
// NOTE: this `z` shadows the loop index over `loops`
for (int z = 0; z < links.size(); ++z)
2023-09-14 19:43:13 +03:00
{
int dim = links[z];
if (dim >= 0)
2023-09-14 19:43:13 +03:00
{
if (dirForLoop->on[dim].first != "*")
{
needToAdd = true;
subs[z] = new SgVarRefExp(findSymbolOrCreate(file, dirForLoop->on[dim].first));
break;
}
2023-09-14 19:43:13 +03:00
}
}
}
// Case 2: the array itself is the ON target; its ON dims are used directly
else if (pairs.second == dirForLoop->arrayRef)
2023-09-14 19:43:13 +03:00
{
for (int z = 0; z < dirForLoop->on.size(); ++z)
2023-09-14 19:43:13 +03:00
{
if (dirForLoop->on[z].first != "*")
{
needToAdd = true;
subs[z] = new SgVarRefExp(findSymbolOrCreate(file, dirForLoop->on[z].first));
break;
}
2023-09-14 19:43:13 +03:00
}
}
// Case 3: the ON target is another (non-template) array; if both are aligned with the
// same template, the first non-"*" ON dim is matched to this array through the template
// links
else if (!dirForLoop->arrayRef->IsTemplate())
{
set<DIST::Array*> realRefsLocal;
getRealArrayRefs(dirForLoop->arrayRef, dirForLoop->arrayRef, realRefsLocal, arrayLinksByFuncCalls);
2023-09-14 19:43:13 +03:00
if (realRefsLocal.size() == 0)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
2023-09-14 19:43:13 +03:00
auto tmplP = (*realRefsLocal.begin())->GetTemplateArray(regId, sharedMemoryParallelization != 0);
auto links = (*realRefsLocal.begin())->GetLinksWithTemplate(regId);
2023-09-14 19:43:13 +03:00
auto tmplP_et = pairs.first->GetTemplateArray(regId, sharedMemoryParallelization != 0);
auto links_et = pairs.first->GetLinksWithTemplate(regId);
2023-09-14 19:43:13 +03:00
if (tmplP == tmplP_et)
2023-09-14 19:43:13 +03:00
{
for (int z = 0; z < dirForLoop->on.size(); ++z)
{
if (dirForLoop->on[z].first != "*")
{
// template dim of the ON target's z-th dim; find this array's dim with the same link
const int idx = links[z];
for (int p = 0; p < links_et.size(); ++p)
2023-09-14 19:43:13 +03:00
{
if (idx >= 0 && links_et[p] == idx)
{
subs[p] = new SgVarRefExp(findSymbolOrCreate(file, dirForLoop->on[z].first));
needToAdd = true;
break;
}
2023-09-14 19:43:13 +03:00
}
break;
2023-09-14 19:43:13 +03:00
}
}
}
}
}
else
{
// shared-memory mode: tie the first dimension (of a read access, else a write access)
// that has non-empty coefficients to this loop's variable
for (const auto& source : { currLoop->readOpsForLoop, currLoop->writeOpsForLoop }) {
auto array_it = source.find(pairs.second);
if (array_it != source.end()) {
bool dim_found = false;
for (int i = 0; i < array_it->second.size(); i++) {
if (array_it->second[i].coefficients.size() != 0)
{
needToAdd = true;
dim_found = true;
subs[i] = new SgVarRefExp(findSymbolOrCreate(file, currLoop->loopSymbol));
break;
}
}
if (dim_found)
break;
}
}
}
2023-09-14 19:43:13 +03:00
}
if (needToAdd)
{
for (int k = 0; k < subs.size(); ++k)
array->addSubscript(*subs[k]);
tieList.push_back(array);
}
}
return tieList;
}
//TODO: need to improve
// Collects the symbols referenced by SPF PARAMETER statements (statements carrying an
// SPF_PARAMETER_OP attribute) strictly inside `loop`. genDirective uses this set to keep
// such variables out of the generated PRIVATE clause.
//
// The previous version also collected the symbols of every other statement into a local
// set that was never used; that traversal is gone, and the returned set is unchanged.
//
// `altLine` is currently unused and kept only for interface compatibility.
static set<SgSymbol*> fillPrivateOnlyFromSpfParameter(SgStatement* loop, const int altLine)
{
    (void)altLine;

    set<SgSymbol*> usedInSpfPar;
    SgStatement* last = loop->lastNodeOfStmt();
    for (SgStatement* st = loop->lexNext(); st != last; st = st->lexNext())
    {
        auto attrSpfPar = getAttributes<SgStatement*, SgStatement*>(st, set<int>{ SPF_PARAMETER_OP });
        if (attrSpfPar.size() == 0) // not an SPF PARAMETER statement
            continue;

        for (int z = 0; z < 3; ++z)
            if (st->expr(z))
                fillUsedSymbols(st->expr(z), usedInSpfPar);
    }
    return usedInSpfPar;
}
// Plans the reordering of a loop nest so that the loops mapped in `parallel` (entries that
// are not "*") come first, keeping the relative order inside each group.
// The required pairwise swaps are recorded in the LoopGraph nodes via setForSwap(); a loop
// that already has a swap partner on both sides is an internal error.
// Returns the loop symbols of the "*" loops placed after the first mapped position; the
// caller adds them to the PRIVATE list. Nothing is done when `parallel == newParallel`.
static set<SgSymbol*> changeLoopOrder(const vector<string>& parallel, const vector<string>& newParallel, vector<LoopGraph*>& loops)
{
    set<SgSymbol*> toPrivate;
    if (parallel == newParallel)
        return toPrivate;

    const int total = (int)parallel.size();

    // targetOrder[pos] = original index that has to end up at position pos
    // positionOf[idx]  = position assigned to original index idx
    vector<int> targetOrder;
    vector<int> positionOf(total);
    auto place = [&](const bool wantStar)
    {
        for (int idx = 0; idx < total; ++idx)
        {
            const bool isStar = (parallel[idx] == "*");
            if (isStar == wantStar)
            {
                targetOrder.push_back(idx);
                positionOf[idx] = (int)targetOrder.size() - 1;
            }
        }
    };
    place(false); // mapped loops first
    place(true);  // then the unmapped ("*") ones

    int firstParallelPos = total;
    for (int idx = 0; idx < total; ++idx)
        if (parallel[idx] != "*" && positionOf[idx] < firstParallelPos)
            firstParallelPos = positionOf[idx];
    if (firstParallelPos == total)
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

    for (int idx = 0; idx < total; ++idx)
        if (parallel[idx] == "*" && positionOf[idx] > firstParallelPos)
            toPrivate.insert(loops[idx]->loop->symbol());

    // selection-style pass: bring targetOrder[pos] into place with a single swap each time
    vector<int> current(total);
    for (int idx = 0; idx < total; ++idx)
        current[idx] = idx;

    for (int pos = 0; pos < total; ++pos)
    {
        if (current[pos] == targetOrder[pos])
            continue;

        int from = 0;
        while (from < total && current[from] != targetOrder[pos])
            ++from;

        if (loops[pos]->getForSwap() == NULL)
            loops[pos]->setForSwap(loops[from]);
        else if (loops[from]->getForSwap() == NULL)
            loops[from]->setForSwap(loops[pos]);
        else
            printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

        std::swap(current[pos], current[from]);
    }
    return toPrivate;
}
// Returns the indices of `toSort` ordered by array name (entry.first.second).
// Names are expected to be unique; a duplicate name is reported as an internal error
// (only the last index of a duplicated name is kept).
static vector<int> sortShadow(const vector<pair<pair<string, string>, vector<pair<int, int>>>>& toSort)
{
    map<string, int> indexByName;
    int position = 0;
    for (const auto& entry : toSort)
        indexByName[entry.first.second] = position++;

    if (indexByName.size() != toSort.size())
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

    vector<int> ordered;
    ordered.reserve(indexByName.size());
    for (const auto& byName : indexByName)
        ordered.push_back(byName.second);
    return ordered;
}
// Generates the DVM PARALLEL directive for `currLoop`.
//
// The directive is produced twice: as text (`directive`, starting with "!DVM$ PARALLEL(")
// and as Sage expressions collected in `dirStatement`:
//   dirStatement[0] - the ON target reference (not built in shared-memory mode),
//   dirStatement[1] - the EXPR_LIST chain of clauses (PRIVATE, TIE, ACROSS, SHADOW_RENEW,
//                     REDUCTION, REMOTE_ACCESS), or NULL if there are none,
//   dirStatement[2] - the list of parallel loop variables.
// Both are wrapped into a CreatedDirective whose line is currLoop->lineNum.
//
// Member data of this ParallelDirective drives the clauses (parallel, on/on2,
// arrayRef/arrayRef2, privates, across, acrossShifts, shadowRenew, shadowRenewShifts,
// reduction, reductionLoc, remoteAccess). acrossShifts and shadowRenewShifts are resized
// here when they are empty.
Directive*
ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, const DistrVariant*>>& distribution,
LoopGraph* currLoop,
DIST::GraphCSR<int, double, attrType>& reducedG,
DIST::Arrays<int>& allArrays, const uint64_t regionId,
const map<DIST::Array*, set<DIST::Array*>>& arrayLinksByFuncCalls)
{
// per-loop analysis results passed on to genBounds() for ACROSS / SHADOW_RENEW
const set<DIST::Array*>& acrossOutAttribute = currLoop->acrossOutAttribute;
const map<DIST::Array*, pair<vector<ArrayOp>, vector<bool>>>& readOps = currLoop->readOps;
map<DIST::Array*, vector<ArrayOp>>& remoteReads = currLoop->remoteRegularReads;
2023-09-14 19:43:13 +03:00
Statement* loop = currLoop->loop;
string directive = "";
vector<Expression*> dirStatement = { NULL, NULL, NULL };
SgForStmt* loopG = (SgForStmt*)loop->GetOriginal();
// NOTE(review): usedInLoop, loopSymbs and allFiles are filled but not read in this function
const set<string> usedInLoop = fillUsedSymbols(loopG);
vector<SgStatement*> moduleList;
findModulesInFile(file, moduleList);
2024-11-21 15:07:16 +03:00
SgStatement* realStat = (SgStatement*)currLoop->getRealStat(file->filename());
2023-09-14 19:43:13 +03:00
SgStatement* parentFunc = getFuncStat(realStat);
const int nested = countPerfectLoopNest(loopG);
vector<SgSymbol*> loopSymbs;
vector<LoopGraph*> loops;
LoopGraph* pLoop = currLoop;
const set<string> allFiles = getAllFilesInProject();
map<string, DIST::Array*> arrayByName;
for (DIST::Array* arr : currLoop->getAllArraysInLoop())
arrayByName[arr->GetName()] = arr;
2023-09-14 19:43:13 +03:00
// Walk the perfect loop nest: skip SPF PARAMETER statements between headers and collect
// the LoopGraph node of each level (following the first child) into `loops`.
for (int z = 0; z < nested; ++z)
{
loopSymbs.push_back(loopG->symbol());
auto next = loopG->lexNext();
auto attrSpfPar = getAttributes<SgStatement*, SgStatement*>(next, set<int>{ SPF_PARAMETER_OP });
// NOTE(review): getAttributes(next, ...) runs before `next` is tested, and `next` is
// dereferenced after the loop without a NULL check
while (attrSpfPar.size() != 0 && next)
{
next = next->lexNext();
attrSpfPar = getAttributes<SgStatement*, SgStatement*>(next, set<int>{ SPF_PARAMETER_OP });
}
if (next->variant() != FOR_NODE && z + 1 < nested)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
loopG = (SgForStmt*)next;
loops.push_back(pLoop);
if (pLoop->children.size())
pLoop = pLoop->children[0];
}
// PARALLEL(<vars>): the non-"*" entries of `parallel`; dirStatement[2] gets the same list
SgExpression* expr = new SgExpression(EXPR_LIST);
SgExpression* p = expr;
directive += "!DVM$ PARALLEL(";
//filter parallel
vector<string> filteredParalel;
for (int i = 0; i < (int)parallel.size(); ++i)
if (parallel[i] != "*")
filteredParalel.push_back(parallel[i]);
// records the loop swaps that move mapped loops in front; returns loop symbols that
// have to become private because of it
set<SgSymbol*> privatesAfterSwap = changeLoopOrder(parallel, filteredParalel, loops);
for (int i = 0; i < (int)filteredParalel.size(); ++i)
{
if (filteredParalel[i] == "*")
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
if (i == 0)
directive += filteredParalel[i];
else
directive += "," + filteredParalel[i];
SgVarRefExp* tmp = NULL;
tmp = new SgVarRefExp(findSymbolOrCreate(file, filteredParalel[i]));
p->setLhs(tmp);
if (i != (int)filteredParalel.size() - 1)
p = createAndSetNext(RIGHT, EXPR_LIST, p);
else
p->setRhs(NULL);
}
// NOTE(review): mapTo stays uninitialized in shared-memory mode; it is only read when
// sharedMemoryParallelization == 0
DIST::Array* mapTo;
2023-09-14 19:43:13 +03:00
dirStatement[2] = new Expression(expr);
if (sharedMemoryParallelization)
{
2023-09-14 19:43:13 +03:00
directive += ")";
}
2023-09-14 19:43:13 +03:00
else
{
mapTo = arrayRef2->IsLoopArray() ? arrayRef : arrayRef2;
2023-09-14 19:43:13 +03:00
directive += ") ON " + mapTo->GetShortName() + "(";
}
2023-09-14 19:43:13 +03:00
// ON <target>(<subscripts>) -> dirStatement[0] (not in shared-memory mode)
SgArrayRefExp* arrayExpr = NULL;
string arrayExprS = "";
if (!sharedMemoryParallelization)
2023-09-14 19:43:13 +03:00
{
auto onTo = arrayRef2->IsLoopArray() ? on : on2;
2023-09-14 19:43:13 +03:00
SgSymbol* symbForPar = NULL;
if (arrayRef->IsTemplate())
{
if (mapTo->IsLoopArray())
2025-02-18 13:45:20 +03:00
symbForPar = findSymbolOrCreate(file, mapTo->GetShortName(), new SgArrayType(*SgTypeInt()), file->GetOriginal()->firstStatement());
2023-09-14 19:43:13 +03:00
else
2025-02-18 13:45:20 +03:00
{
symbForPar = (SgSymbol*)mapTo->GetNameInLocationS(parentFunc);
}
2023-09-14 19:43:13 +03:00
}
else
2025-02-18 13:45:20 +03:00
symbForPar = (SgSymbol*)arrayRef->GetNameInLocationS(parentFunc);
2023-09-14 19:43:13 +03:00
arrayExpr = new SgArrayRefExp(*symbForPar);
arrayExprS = "";
for (int i = 0; i < (int)onTo.size(); ++i)
{
const pair<int, int>& coeffs = onTo[i].second;
assert((coeffs.first != 0 && onTo[i].first != "*") || onTo[i].first == "*");
if (i != 0)
arrayExprS += ",";
if (onTo[i].first == "*")
{
arrayExprS += "*";
SgVarRefExp* varExpr = new SgVarRefExp(findSymbolOrCreate(file, "*"));
arrayExpr->addSubscript(*varExpr);
}
else
{
arrayExprS += genStringExpr(onTo[i].first, coeffs);
arrayExpr->addSubscript(*genSgExpr(file, onTo[i].first, coeffs));
}
}
directive += arrayExprS + ")";
dirStatement[0] = new Expression(arrayExpr);
}
// dirStatement[1]: clause list; each section below appends through `expr` / `p`
expr = new SgExpression(EXPR_LIST);
p = expr;
dirStatement[1] = NULL;
// PRIVATE: own privates plus loop symbols made private by the reordering (deduplicated
// by name); variables referenced in SPF PARAMETER statements inside the loop are left out
set<string> uniqNamesOfPrivates;
for (auto& elem : privates)
uniqNamesOfPrivates.insert(elem->identifier());
auto unitedPrivates = privates;
for (auto& elem : privatesAfterSwap)
{
if (uniqNamesOfPrivates.find(elem->identifier()) == uniqNamesOfPrivates.end())
{
unitedPrivates.insert(new Symbol(elem));
uniqNamesOfPrivates.insert(elem->identifier());
}
}
if (unitedPrivates.size() != 0)
{
p = createAndSetNext(LEFT, ACC_PRIVATE_OP, p);
directive += ", PRIVATE(";
int k = 0;
vector<SgExpression*> list;
auto spfParVars = fillPrivateOnlyFromSpfParameter(loop, currLoop->lineNum < 0 ? currLoop->altLineNum : 0);
for (auto& privVar : setToMapWithSortByStr(unitedPrivates))
{
bool isSfpPriv = false;
for (auto& elem : spfParVars)
if (OriginalSymbol(elem)->identifier() == string(OriginalSymbol(privVar.second)->identifier()))
isSfpPriv = true;
if (isSfpPriv)
continue;
directive += (k != 0) ? "," + privVar.first : privVar.first;
2025-02-18 18:57:05 +03:00
list.push_back(new SgVarRefExp(getNameInLocation(privVar.second, parentFunc)));
2023-09-14 19:43:13 +03:00
++k;
}
directive += ")";
dirStatement[1] = new Expression(expr);
p->setLhs(makeExprList(list));
}
// TIE: always in shared-memory mode; otherwise for ACROSS arrays other than the ON
// target, and only when arrayRef2 is not a loop array
if (sharedMemoryParallelization || (across.size() != 0 && !arrayRef2->IsLoopArray()))
2023-09-14 19:43:13 +03:00
{
vector<LoopGraph*> loopsTie;
for (int i = 0; i < (int)parallel.size(); ++i)
if (parallel[i] != "*")
loopsTie.push_back(loops[i]);
2023-09-14 19:43:13 +03:00
set<DIST::Array*> onlyFor;
if (sharedMemoryParallelization == 0 && across.size())
{
for (int k = 0; k < (int)across.size(); ++k)
2023-09-14 19:43:13 +03:00
{
DIST::Array* currArray = allArrays.GetArrayByName(across[k].first.second);
if (currArray != mapTo)
onlyFor.insert(currArray);
2023-09-14 19:43:13 +03:00
}
}
vector<SgExpression*> tieList;
if (sharedMemoryParallelization)
2025-02-18 18:57:05 +03:00
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, file, parentFunc, onlyFor, uniqNamesOfPrivates);
else if (onlyFor.size()) // not MPI regime
2025-02-18 18:57:05 +03:00
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, file, parentFunc, onlyFor, uniqNamesOfPrivates);
2023-09-14 19:43:13 +03:00
if (tieList.size())
{
if (dirStatement[1] != NULL)
2023-09-14 19:43:13 +03:00
{
expr = createAndSetNext(RIGHT, EXPR_LIST, expr);
p = expr;
}
p = createAndSetNext(LEFT, ACC_TIE_OP, p);
p->setLhs(makeExprList(tieList));
2023-09-14 19:43:13 +03:00
directive += ", TIE(";
int k = 0;
for (auto& tieL : tieList)
{
if (k != 0)
directive += ",";
directive += tieL->unparse();
++k;
2023-09-14 19:43:13 +03:00
}
directive += ")";
2023-09-14 19:43:13 +03:00
}
}
// ACROSS: arrays in name order (sortShadow); arrays in acrossOutAttribute go into the
// OUT list (acr_out), the others into acr_in
set<DIST::Array*> arraysInAcross;
if (across.size() != 0)
{
if (acrossShifts.size() == 0)
{
acrossShifts.resize(across.size());
for (int i = 0; i < across.size(); ++i)
acrossShifts[i].resize(across[i].second.size());
}
//TODO: add "OUT" key for string representation
string acrossAdd = ", ACROSS(";
int inserted = 0;
SgExpression* acr_out = new SgExpression(EXPR_LIST);
SgExpression* p_out = acr_out;
SgExpression* acr_in = new SgExpression(EXPR_LIST);
SgExpression* p_in = acr_in;
SgExpression* acr_op = NULL;
int inCount = 0;
int outCount = 0;
vector<int> ordered = sortShadow(across);
for (int k = 0; k < (int)across.size(); ++k)
{
const int i1 = ordered[k];
vector<map<pair<int, int>, int>> shiftsByAccess;
DIST::Array* acrossArray = NULL;
2023-09-14 19:43:13 +03:00
if (!sharedMemoryParallelization)
{
acrossArray = allArrays.GetArrayByName(across[i1].first.second);
if (acrossArray == NULL)
{
//TODO: need to fix SageDep analysis or use IR
// unknown array: skipped if it is declared as not distributed, otherwise an error
bool notPrivate = true;
for (auto& arrayPair : declaredArrays)
{
auto array = arrayPair.second.first;
if (array->GetName() == across[i1].first.second)
{
if (array->IsNotDistribute())
notPrivate = false;
break;
}
}
if (notPrivate)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
else
continue;
}
}
else
{
auto currArray_it = arrayByName.find(across[i1].first.second);
if (currArray_it == arrayByName.end())
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
acrossArray = currArray_it->second;
}
2023-09-14 19:43:13 +03:00
bool isOut = acrossOutAttribute.find(acrossArray) != acrossOutAttribute.end();
string bounds = genBounds(across[i1], acrossShifts[i1], reducedG, allArrays, acrossArray, remoteReads, readOps, true, regionId, distribution, arraysInAcross, shiftsByAccess, arrayLinksByFuncCalls);
2023-09-14 19:43:13 +03:00
if (bounds != "")
{
if (inserted != 0)
{
acrossAdd += ",";
if (isOut)
{
if (outCount > 0)
p_out = createAndSetNext(RIGHT, EXPR_LIST, p_out);
outCount++;
p = p_out;
}
else
{
if (inCount > 0)
p_in = createAndSetNext(RIGHT, EXPR_LIST, p_in);
inCount++;
p = p_in;
}
}
else if (inserted == 0)
{
// first ACROSS entry: open the ACROSS_OP node in the clause list
if (dirStatement[1] != NULL)
expr = createAndSetNext(RIGHT, EXPR_LIST, expr);
acr_op = createAndSetNext(LEFT, ACROSS_OP, expr);
if (isOut)
{
outCount++;
p = p_out;
}
else
{
inCount++;
p = p_in;
}
}
acrossAdd += across[i1].first.first + "(" + bounds + ")";
2025-02-18 13:45:20 +03:00
SgArrayRefExp* newArrayRef = new SgArrayRefExp(*((SgSymbol*)acrossArray->GetNameInLocationS(parentFunc)));
newArrayRef->addAttribute(ARRAY_REF, acrossArray, sizeof(DIST::Array));
2023-09-14 19:43:13 +03:00
for (auto& elem : genSubscripts(across[i1].second, acrossShifts[i1]))
newArrayRef->addSubscript(*elem);
p->setLhs(newArrayRef);
inserted++;
}
}
acrossAdd += ")";
if (inserted > 0)
{
directive += acrossAdd;
if (dirStatement[1] == NULL)
dirStatement[1] = new Expression(expr);
// NOTE(review): OUT(...) is attached whenever acrossOutAttribute is non-empty, even if
// no OUT array was inserted (outCount == 0); testing outCount may be intended
if (acrossOutAttribute.size() > 0)
{
SgExpression* tmp = new SgExpression(DDOT, new SgKeywordValExp("OUT"), acr_out, NULL);
acr_op->setLhs(*tmp);
if (inCount != 0)
acr_op->setRhs(acr_in);
}
else
acr_op->setLhs(acr_in);
}
}
// SHADOW_RENEW (not in shared-memory mode); "(CORNER)" is added when needCorner() finds
// an access with a non-zero shift in more than one dimension
if (shadowRenew.size() != 0 && sharedMemoryParallelization == 0)
2023-09-14 19:43:13 +03:00
{
if (shadowRenewShifts.size() == 0)
{
shadowRenewShifts.resize(shadowRenew.size());
for (int i = 0; i < shadowRenew.size(); ++i)
shadowRenewShifts[i].resize(shadowRenew[i].second.size());
}
string shadowAdd = ", SHADOW_RENEW(";
int inserted = 0;
vector<int> ordered = sortShadow(shadowRenew);
for (int k = 0; k < (int)shadowRenew.size(); ++k)
{
const int i1 = ordered[k];
vector<map<pair<int, int>, int>> shiftsByAccess;
DIST::Array* shadowArray = allArrays.GetArrayByName(shadowRenew[i1].first.second);
if (shadowArray == NULL)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
const string bounds = genBounds(shadowRenew[i1], shadowRenewShifts[i1], reducedG, allArrays, shadowArray, remoteReads, readOps, false, regionId, distribution, arraysInAcross, shiftsByAccess, arrayLinksByFuncCalls);
2023-09-14 19:43:13 +03:00
if (bounds != "")
{
// NOTE(review): re-declares (shadows) shadowArray with the same lookup as above
DIST::Array* shadowArray = allArrays.GetArrayByName(shadowRenew[i1].first.second);
2023-09-14 19:43:13 +03:00
if (inserted != 0)
{
shadowAdd += ",";
p = createAndSetNext(RIGHT, EXPR_LIST, p);
}
else if (inserted == 0)
{
if (dirStatement[1] != NULL)
{
expr = createAndSetNext(RIGHT, EXPR_LIST, expr);
p = expr;
}
p = createAndSetNext(LEFT, SHADOW_RENEW_OP, p);
p = createAndSetNext(LEFT, EXPR_LIST, p);
}
shadowAdd += shadowRenew[i1].first.first + "(" + bounds + ")";
2025-02-18 13:45:20 +03:00
SgArrayRefExp* newArrayRef = new SgArrayRefExp(*((SgSymbol*)shadowArray->GetNameInLocationS(parentFunc)));
newArrayRef->addAttribute(ARRAY_REF, shadowArray, sizeof(DIST::Array));
2023-09-14 19:43:13 +03:00
for (auto& elem : genSubscripts(shadowRenew[i1].second, shadowRenewShifts[i1]))
newArrayRef->addSubscript(*elem);
if (shadowRenew[i1].second.size() > 1 && needCorner(shadowArray, shiftsByAccess, loop))
2023-09-14 19:43:13 +03:00
{
SgExpression* tmp = new SgExpression(ARRAY_OP, newArrayRef, NULL, NULL);
p->setLhs(*tmp);
shadowAdd += "(CORNER)";
SgKeywordValExp* tmp1 = new SgKeywordValExp("CORNER");
p->lhs()->setRhs(tmp1);
}
else
p->setLhs(newArrayRef);
inserted++;
}
}
shadowAdd += ")";
if (inserted > 0)
{
directive += shadowAdd;
if (dirStatement[1] == NULL)
dirStatement[1] = new Expression(expr);
}
}
// REDUCTION(op(var), ...); a name changed by correctSymbolModuleName (presumably a
// module-qualified one) is resolved in its module's scope via getModuleScope
if (reduction.size() != 0)
{
if (dirStatement[1] != NULL)
{
expr = createAndSetNext(RIGHT, EXPR_LIST, expr);
p = expr;
}
p = createAndSetNext(LEFT, REDUCTION_OP, p);
p = createAndSetNext(LEFT, EXPR_LIST, p);
directive += ", REDUCTION(";
int k = 0;
for (auto it = reduction.begin(); it != reduction.end(); ++it)
{
const string& nameGroup = it->first;
2025-02-18 18:57:05 +03:00
for (auto& red : it->second)
2023-09-14 19:43:13 +03:00
{
if (k != 0)
{
directive += ",";
p = createAndSetNext(RIGHT, EXPR_LIST, p);
}
2025-02-18 18:57:05 +03:00
SgSymbol* redS;
string clearName = correctSymbolModuleName(red);
if (clearName != red)
2025-05-11 09:17:16 +03:00
redS = getNameInLocation(parentFunc, red, getModuleScope(red, moduleList, parentFunc)->symbol()->identifier());
2025-02-18 18:57:05 +03:00
else
redS = findSymbolOrCreate(file, clearName, NULL, parentFunc);
2023-09-14 19:43:13 +03:00
directive += nameGroup + "(" + redS->identifier() + ")";
SgVarRefExp* tmp2 = new SgVarRefExp(redS);
SgFunctionCallExp* tmp1 = new SgFunctionCallExp(*findSymbolOrCreate(file, nameGroup), *tmp2);
p->setLhs(tmp1);
++k;
}
}
// the REDUCTION clause stays open when location reductions follow
if (reductionLoc.size() != 0)
directive += ", ";
else
{
directive += ")";
if (dirStatement[1] == NULL)
dirStatement[1] = new Expression(expr);
}
}
// location reductions op(var, loc, n); appended to the REDUCTION clause opened above,
// or opening a new one when there are no plain reductions
if (reductionLoc.size() != 0)
{
if (dirStatement[1] != NULL && reduction.size() == 0)
{
expr = createAndSetNext(RIGHT, EXPR_LIST, expr);
p = expr;
}
if (reduction.size() == 0)
{
p = createAndSetNext(LEFT, REDUCTION_OP, p);
p = createAndSetNext(LEFT, EXPR_LIST, p);
directive += ", REDUCTION(";
}
else
p = createAndSetNext(RIGHT, EXPR_LIST, p);
int k = 0;
for (auto it = reductionLoc.begin(); it != reductionLoc.end(); ++it)
{
const string& nameGroup = it->first;
for (auto& list : it->second)
{
if (k != 0)
{
directive += ",";
p = createAndSetNext(RIGHT, EXPR_LIST, p);
}
2025-02-18 18:57:05 +03:00
SgSymbol *redS1, *redS2;
string clearName1 = correctSymbolModuleName(get<0>(list));
string clearName2 = correctSymbolModuleName(get<1>(list));
2023-09-14 19:43:13 +03:00
2025-02-18 18:57:05 +03:00
if (clearName1 != get<0>(list))
2025-05-11 09:17:16 +03:00
redS1 = getNameInLocation(parentFunc, get<0>(list), getModuleScope(get<0>(list), moduleList, parentFunc)->symbol()->identifier());
2025-02-18 18:57:05 +03:00
else
redS1 = findSymbolOrCreate(file, clearName1, NULL, parentFunc);
if (clearName2 != get<1>(list))
2025-05-11 09:17:16 +03:00
redS2 = getNameInLocation(parentFunc, get<1>(list), getModuleScope(get<1>(list), moduleList, parentFunc)->symbol()->identifier());
2025-02-18 18:57:05 +03:00
else
redS2 = findSymbolOrCreate(file, clearName2, NULL, parentFunc);
2023-09-14 19:43:13 +03:00
directive += nameGroup + "(" + redS1->identifier() + ", " + redS2->identifier() + ", " + std::to_string(get<2>(list)) + ")";
SgFunctionCallExp* tmp1 = new SgFunctionCallExp(*findSymbolOrCreate(file, nameGroup));
tmp1->addArg(*new SgVarRefExp(redS1));
tmp1->addArg(*new SgVarRefExp(redS2));
tmp1->addArg(*new SgValueExp(get<2>(list)));
p->setLhs(tmp1);
++k;
}
}
directive += ")";
if (dirStatement[1] == NULL)
dirStatement[1] = new Expression(expr);
}
// REMOTE_ACCESS (not in shared-memory mode)
if (remoteAccess.size() != 0 && sharedMemoryParallelization == 0)
2023-09-14 19:43:13 +03:00
{
if (dirStatement[1] != NULL)
{
expr = createAndSetNext(RIGHT, EXPR_LIST, expr);
p = expr;
}
p = createAndSetNext(LEFT, REMOTE_ACCESS_OP, p);
p = createAndSetNext(LEFT, EXPR_LIST, p);
directive += ", REMOTE_ACCESS(";
int k = 0;
for (auto it = remoteAccess.begin(); it != remoteAccess.end(); ++it, ++k)
{
directive += it->first.first.first + "(";
directive += it->first.second + ")";
DIST::Array* currArray = allArrays.GetArrayByName(it->first.first.second);
2025-02-18 13:45:20 +03:00
SgArrayRefExp* tmp = new SgArrayRefExp(*((SgSymbol*)currArray->GetNameInLocationS(parentFunc)), *it->second);
2023-09-14 19:43:13 +03:00
tmp->addAttribute(ARRAY_REF, currArray, sizeof(DIST::Array));
p->setLhs(tmp);
if (k != remoteAccess.size() - 1)
{
directive += ",";
p = createAndSetNext(RIGHT, EXPR_LIST, p);
}
}
directive += ")";
if (dirStatement[1] == NULL)
dirStatement[1] = new Expression(expr);
}
directive += "\n";
auto dir = new CreatedDirective(directive, dirStatement);
dir->line = currLoop->lineNum;
return dir;
}
// Fills the EXPR_LIST chain starting at `rule` with one distribution format per dimension:
// "*" for dist::NONE and "BLOCK" for dist::BLOCK; any other value leaves that element's
// left operand unset. When `newOrder` is non-empty, dimension i takes the rule of
// dimension newOrder[i]. A new list element is appended after every dimension but the last.
void DistrVariant::GenRule(File *file, Expression *rule, const vector<int> &newOrder) const
{
    const int dims = (int)distRule.size();
    for (int i = 0; i < dims; ++i)
    {
        const auto kind = newOrder.empty() ? distRule[i] : distRule[newOrder[i]];
        if (kind == dist::NONE)
            rule->setLhs(new SgVarRefExp(findSymbolOrCreate(file, "*")));
        else if (kind == dist::BLOCK)
            rule->setLhs(new SgVarRefExp(findSymbolOrCreate(file, "BLOCK")));

        const bool isLast = (i == dims - 1);
        if (!isLast)
        {
            rule->setRhs(new SgExpression(EXPR_LIST));
            rule = new Expression(rule->rhs());
        }
    }
}
// Returns the distribution formats as separate expressions: "*" for dist::NONE and "BLOCK"
// for dist::BLOCK; dimensions with any other rule produce no entry. When `newOrder` is
// non-empty, position i uses the rule of dimension newOrder[i].
vector<Expression*> DistrVariant::GenRuleSt(File *file, const vector<int> &newOrder) const
{
    vector<Expression*> formats;
    const bool keepOrder = newOrder.empty();
    for (int i = 0; i < (int)distRule.size(); ++i)
    {
        const auto kind = keepOrder ? distRule[i] : distRule[newOrder[i]];
        if (kind == dist::NONE)
            formats.push_back(new Expression(new SgVarRefExp(findSymbolOrCreate(file, "*"))));
        else if (kind == dist::BLOCK)
            formats.push_back(new Expression(new SgVarRefExp(findSymbolOrCreate(file, "BLOCK"))));
    }
    return formats;
}
// Builds one `variant` directive statement per entry of `rules`: expression 0 is a list
// with the array/template name, expression 1 is the distribution rule selected by
// rules[i] from distrRules[i]. An out-of-range selection (or fewer distrRules than rules)
// is reported as an internal error.
vector<Statement*> DataDirective::GenRule(File *file, const vector<int> &rules, const int variant) const
{
    vector<Statement*> produced;
    if (distrRules.size() < rules.size())
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

    for (int i = 0; i < (int)rules.size(); ++i)
    {
        const auto &entry = distrRules[i];
        if (!(rules[i] < entry.second.size()))
        {
            printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
            continue;
        }

        const string arrayName = entry.first->GetShortName();
        SgStatement *dirSt = new SgStatement(variant, NULL, NULL, NULL, NULL, NULL);
        SgVarRefExp *nameRef = new SgVarRefExp(*findSymbolOrCreate(file, arrayName));
        SgExpression *ruleList = new SgExpression(EXPR_LIST);
        entry.second[rules[i]].GenRule(file, new Expression(ruleList), entry.first->GetNewTemplateDimsOrder());

        SgExpression *nameList = new SgExpression(EXPR_LIST, nameRef, NULL, NULL);
        dirSt->setExpression(0, *nameList);
        dirSt->setExpression(1, *ruleList);
        produced.push_back(new Statement(dirSt));
    }
    return produced;
}
// Generates the `variant` statement for every align rule, skipping rules for which
// AlignRule::GenRule produces nothing.
vector<Statement*> DataDirective::GenAlignsRules(File *file, const int variant) const
{
    vector<Statement*> generated;
    for (const auto &rule : alignRules)
    {
        Statement *st = rule.GenRule(file, variant);
        if (st != NULL)
            generated.push_back(st);
    }
    return generated;
}
// Builds the ALIGN/REALIGN-style statement (`variant`) for this rule:
//   expression 0 - list with the aligned array name,
//   expression 1 - one dummy expression per aligned-array dimension (from alignRule),
//   expression 2 - reference to `alignWith`, "*" in every dimension that is not a target
//                  of alignRuleWith, reordered by GetNewTemplateDimsOrder() when set.
// Returns NULL for DVM_REALIGN_DIR when GetLocation().first == 0 (the original comment
// calls this the "local" case).
Statement* AlignRule::GenRule(File *file, const int variant) const
{
    // local and realign
    const bool isLocal = (alignArray->GetLocation().first == 0);
    if (isLocal && variant == DVM_REALIGN_DIR)
        return NULL;

    SgStatement *dirSt = new SgStatement(variant, NULL, NULL, NULL, NULL, NULL);

    // expression 0: the aligned array
    SgVarRefExp *alignedRef = new SgVarRefExp(findSymbolOrCreate(file, alignArray->GetShortName()));
    dirSt->setExpression(0, *new SgExpression(EXPR_LIST, alignedRef, NULL, NULL));

    // expression 1: attached first, then filled element by element
    SgExpression *cursor = new SgExpression(EXPR_LIST);
    dirSt->setExpression(1, *cursor);
    const int ruleCount = (int)alignRule.size();
    for (int i = 0; i < ruleCount; ++i)
    {
        cursor->setLhs(genSgExpr(file, alignNames[i], alignRule[i]));
        if (i + 1 < ruleCount)
        {
            cursor->setRhs(new SgExpression(EXPR_LIST));
            cursor = cursor->rhs();
        }
    }

    // expression 2: the align target, typed as an integer array
    SgSymbol *withSymb = &(findSymbolOrCreate(file, alignWith->GetShortName())->copy());
    withSymb->setType(new SgArrayType(*SgTypeInt()));
    SgArrayRefExp *withRef = new SgArrayRefExp(*withSymb);

    const int withDims = alignWith->GetDimSize();
    vector<SgExpression*> perDim(withDims);
    for (int d = 0; d < withDims; ++d)
        perDim[d] = new SgVarRefExp(findSymbolOrCreate(file, "*"));

    for (int i = 0; i < (int)alignRuleWith.size(); ++i)
    {
        const auto targetDim = alignRuleWith[i].first;
        if (targetDim != -1)
            perDim[targetDim] = genSgExpr(file, alignNames[i], alignRuleWith[i].second);
    }

    const auto newOrder = alignWith->GetNewTemplateDimsOrder();
    if (!newOrder.empty())
    {
        const vector<SgExpression*> beforeReorder(perDim);
        for (int i = 0; i < (int)newOrder.size(); ++i)
            perDim[i] = beforeReorder[newOrder[i]];
    }

    for (SgExpression *sub : perDim)
        withRef->addSubscript(*sub);
    dirSt->setExpression(2, *withRef);
    return new Statement(dirSt);
}
// Builds the SHADOW specification for one array.
//
// shadowSpecs.first  - array name
// shadowSpecs.second - (left, right) shadow widths per dimension
//
// Returns (list with the array name, EXPR_LIST of "left:right" DDOT pairs, one per
// dimension), or (NULL, NULL) when every width is zero and no SHADOW is needed.
//
// The "nothing to do" case is now decided before any expression is built, so no tree is
// allocated just to be dropped. The symbol lookup/creation for the array name still happens
// first, exactly as before, so its side effect does not depend on the result.
pair<SgExpression*, SgExpression*> genShadowSpec(SgFile *file, const pair<string, const vector<pair<int, int>>> &shadowSpecs)
{
    SgSymbol *arraySymb = findSymbolOrCreate(file, shadowSpecs.first);

    const vector<pair<int, int>> &widths = shadowSpecs.second;
    const bool needInsert = std::any_of(widths.begin(), widths.end(),
                                        [](const pair<int, int> &w) { return w.first != 0 || w.second != 0; });
    if (!needInsert)
        return pair<SgExpression*, SgExpression*>(); // value-initialized: both NULL

    pair<SgExpression*, SgExpression*> result;
    result.first = new SgExpression(EXPR_LIST, new SgVarRefExp(arraySymb), NULL, NULL, NULL);

    SgExpression *listEx = new SgExpression(EXPR_LIST);
    result.second = listEx;
    for (size_t k = 0; k < widths.size(); ++k)
    {
        SgValueExp *left = new SgValueExp(widths[k].first);
        SgValueExp *right = new SgValueExp(widths[k].second);
        listEx->setLhs(new SgExpression(DDOT, left, right, NULL));
        if (k + 1 < widths.size())
        {
            SgExpression *next = new SgExpression(EXPR_LIST);
            listEx->setRhs(next);
            listEx = listEx->rhs();
        }
    }
    return result;
}
//TODO: check this
// Widens an existing SHADOW width list in place: for dimension k, the left and right bounds
// of the k-th DDOT element of `listEx` are raised to shadowSpecs[k] when they are smaller.
// Stops when the list runs out before the specs do; extra list elements are left untouched.
void correctShadowSpec(SgExpression *listEx, const vector<pair<int, int>> &shadowSpecs)
{
    // raise a single bound (assumed to be an integer value node) to at least `minVal`
    auto raiseTo = [](SgExpression *bound, const int minVal)
    {
        if (bound && bound->valueInteger() < minVal)
            static_cast<SgValueExp*>(bound)->setValue(minVal);
    };

    for (const auto &spec : shadowSpecs)
    {
        if (listEx == NULL)
            break;

        SgExpression *dim = listEx->lhs();
        if (dim)
        {
            raiseTo(dim->lhs(), spec.first);
            raiseTo(dim->rhs(), spec.second);
        }
        listEx = listEx->rhs();
    }
}