Files
SAPFOR/Sapfor/_src/DirectiveProcessing/remote_access.cpp
2025-03-25 20:39:29 +03:00

1031 lines
35 KiB
C++

#include "../Utils/leak_detector.h"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cstdint>
#include <string>
#include <fstream>
#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <set>
#include <utility>
#include <assert.h>
#include "dvm.h"
#include "../LoopAnalyzer/loop_analyzer.h"
#include "../Utils/types.h"
#include "../Utils/errors.h"
#include "../Utils/SgUtils.h"
#include "../Distribution/Arrays.h"
#include "../GraphCall/graph_calls.h"
#include "../GraphCall/graph_calls_func.h"
#include "../GraphLoop/graph_loops_func.h"
#include "remote_access.h"
using std::vector;
using std::pair;
using std::tuple;
using std::map;
using std::set;
using std::make_pair;
using std::make_tuple;
using std::get;
using std::string;
using std::wstring;
// Appends to `remotes` every ARRAY_REF node found in the expression tree `ex`.
// An ARRAY_REF is treated as a leaf: its own subscripts are not searched.
void fillRemoteFromDir(SgExpression* ex, vector<SgExpression*>& remotes)
{
    if (ex == NULL)
        return;

    if (ex->variant() == ARRAY_REF)
    {
        remotes.push_back(ex);
        return;
    }

    fillRemoteFromDir(ex->lhs(), remotes);
    fillRemoteFromDir(ex->rhs(), remotes);
}
// Merges the array references of a REMOTE_ACCESS specification.
//
// All ARRAY_REF nodes are gathered from `remote` and, when given, from every
// expression in `newRemotes`. They are then grouped by array name. For each
// array:
//   - if some reference has ':' (DDOT) in every subscript, only that one is kept,
//     because it already covers all the others;
//   - otherwise every textually distinct reference is kept.
//
// Returns NULL when no array reference was found at all. If `remote` is an
// expression list (or NULL), a freshly built list is returned. Otherwise the
// list is stored as `remote`'s lhs and `remote` itself is returned.
SgExpression* remoteAggregation(SgExpression* remote, const vector<SgExpression*>* newRemotes)
{
    vector<SgExpression*> remotes;
    fillRemoteFromDir(remote, remotes);
    if (newRemotes)
        for (auto elem : *newRemotes)
            fillRemoteFromDir(elem, remotes);

    if (remotes.empty())
        return NULL;

    map<string, vector<SgExpression*>> remByName;
    for (auto& rem : remotes)
        remByName[rem->symbol()->identifier()].push_back(rem);

    set<string> existsRemotes;
    vector<SgExpression*> result;
    for (auto& byName : remByName)
    {
        // look for a reference whose every subscript is ':'
        SgExpression* fullDot = NULL;
        for (auto& rem : byName.second)
        {
            int countDim = 0;
            int countDD = 0;
            for (SgExpression* subs = rem->lhs(); subs; subs = subs->rhs())
            {
                countDim++;
                if (subs->lhs() && subs->lhs()->variant() == DDOT)
                    countDD++;
            }
            if (countDim == countDD && countDD != 0)
            {
                fullDot = rem;
                break;
            }
        }

        if (fullDot)
            result.push_back(fullDot);
        else
        {
            //TODO: group by the same parts, eg. A(:,N,:) + A(1,N,:) -> A(:,N,:)
            for (auto& rem : byName.second)
            {
                const string curr = rem->unparse();
                if (existsRemotes.insert(curr).second)
                    result.push_back(rem);
            }
        }
    }

    // a NULL `remote` has no lhs to update: hand back the new list instead of crashing
    if (remote == NULL || isSgExprListExp(remote))
        return makeExprList(result);

    remote->setLhs(makeExprList(result));
    return remote;
}
// Resolves `name` through the unique-symbol table and looks the resulting
// short name up in `allArrays`. Returns whatever GetArrayByName returns
// (possibly NULL when the array is unknown).
static DIST::Array* GetArrayByShortName(const DIST::Arrays<int> &allArrays, SgSymbol *name)
{
    auto uniqEntry = getFromUniqTable(name);
    string shortName = getShortName(uniqEntry);
    DIST::Array* found = allArrays.GetArrayByName(shortName);
    return found;
}
// For an array reference, returns the expression stored in its ARRAY_REF
// attribute (the last such attribute if there are several). Otherwise, or
// if no such attribute exists, returns `ex` unchanged.
static SgExpression* getArrayRefCopyFromAttribute(SgExpression* ex)
{
    if (!isArrayRef(ex))
        return ex;

    // scanning from the end means the last matching attribute wins
    for (int attrIdx = ex->numberOfAttributes() - 1; attrIdx >= 0; --attrIdx)
    {
        if (ex->attributeType(attrIdx) == ARRAY_REF)
            return (SgExpression*)(ex->getAttribute(attrIdx)->getAttributeData());
    }
    return ex;
}
// Decides whether `arrayRef` needs a REMOTE_ACCESS entry.
//
// Every real array the reference may denote (via getRealArrayRefs) is
// inspected. For that array, the first distribution rule in `data` whose
// template matches the array's template is taken. The reference qualifies
// when some dimension linked to that template (link index >= 0) is
// distributed BLOCK in the rule variant selected by `currVar`.
//
// On success, the pair <copy of arrayRef, attribute copy> is appended to
// `remotes` and true is returned. At most one pair is added per call.
static bool checkArrayRef(SgExpression* arrayRef, vector<pair<SgExpression*, SgExpression*>>& remotes, const DIST::Arrays<int>& allArrays,
                          const DataDirective& data, const vector<int>& currVar, const uint64_t regionID,
                          const map<DIST::Array*, set<DIST::Array*>>& arrayLinksByFuncCalls)
{
    DIST::Array* declArray = GetArrayByShortName(allArrays, OriginalSymbol(arrayRef->symbol()));
    set<DIST::Array*> realRefs;
    getRealArrayRefs(declArray, declArray, realRefs, arrayLinksByFuncCalls);

    for (DIST::Array* realArray : realRefs)
    {
        if (realArray == NULL)
            continue;

        DIST::Array* templ = realArray->GetTemplateArray(regionID);
        checkNull(templ, convertFileName(__FILE__).c_str(), __LINE__);
        auto links = realArray->GetLinksWithTemplate(regionID);

        bool hasBlockDim = false;
        for (size_t ruleIdx = 0; ruleIdx < data.distrRules.size(); ++ruleIdx)
        {
            if (data.distrRules[ruleIdx].first != templ)
                continue;

            const vector<dist>& rule = data.distrRules[ruleIdx].second[currVar[ruleIdx]].distRule;
            for (size_t k = 0; k < links.size(); ++k)
            {
                const int templDim = links[k];
                if (templDim >= 0 && rule[templDim] == BLOCK)
                {
                    hasBlockDim = true;
                    break;
                }
            }
            // only the first rule for this template is considered
            break;
        }

        if (hasBlockDim)
        {
            remotes.push_back(make_pair(arrayRef->copyPtr(), getArrayRefCopyFromAttribute(arrayRef)));
            return true;
        }
    }
    return false;
}
// Recursively walks `expr` (which belongs to statement `st`) and runs
// checkArrayRef on every array reference it meets. Candidates are appended
// to `remotes`.
//
// Actual arguments of a function call, or of the CALL statement `st` when
// `expr` is its argument list, are checked only when they carry subscripts.
// Their sub-expressions are still searched.
// Returns true if at least one reference was added.
static bool findAllArraysForRemote(SgStatement* st, SgExpression* expr,
                                   vector<pair<SgExpression*, SgExpression*>>& remotes, const DIST::Arrays<int>& allArrays,
                                   const DataDirective& data, const vector<int>& currVar, const uint64_t regionID,
                                   const map<DIST::Array*, set<DIST::Array*>>& arrayLinksByFuncCalls)
{
    if (expr == NULL)
        return false;

    auto descend = [&](SgExpression* sub) -> bool
    {
        if (sub == NULL)
            return false;
        return findAllArraysForRemote(st, sub, remotes, allArrays, data, currVar, regionID, arrayLinksByFuncCalls);
    };
    auto check = [&](SgExpression* ref) -> bool
    {
        return checkArrayRef(ref, remotes, allArrays, data, currVar, regionID, arrayLinksByFuncCalls);
    };

    bool found = false;
    const bool isProcCall = (st->variant() == PROC_STAT && st->expr(0) == expr);
    if (isProcCall || expr->variant() == FUNC_CALL)
    {
        SgExpression* argList = isProcCall ? expr : expr->lhs();
        for (; argList != NULL; argList = argList->rhs())
        {
            SgExpression* arg = argList->lhs();
            // a bare array name passed as an argument is not a remote candidate
            if (isArrayRef(arg) && isSgArrayRefExp(arg)->numberOfSubscripts() != 0)
                found |= check(arg);
            found |= descend(arg->lhs());
            found |= descend(arg->rhs());
        }
        return found;
    }

    if (isArrayRef(expr))
        found |= check(expr);
    found |= descend(expr->lhs());
    found |= descend(expr->rhs());
    return found;
}
// Returns true when the expression tree `ex` contains no node whose variant
// is listed in `noSimpleVars`. Exception: a VAR_REF (if listed) is
// acceptable as long as its original name is not in `writeVars`.
// A NULL tree is simple.
static bool checkExpr(SgExpression *ex, const set<int>& noSimpleVars, const set<string>& writeVars)
{
    if (ex == NULL)
        return true;

    const int kind = ex->variant();
    if (noSimpleVars.count(kind) != 0)
    {
        if (kind != VAR_REF)
            return false;
        return writeVars.count(OriginalSymbol(ex->symbol())->identifier()) == 0;
    }

    // the right operand is only inspected when the left one is simple
    if (!checkExpr(ex->lhs(), noSimpleVars, writeVars))
        return false;
    return checkExpr(ex->rhs(), noSimpleVars, writeVars);
}
// Walks `ex` and, for every call of a non-intrinsic function, adds to
// `writeVars` the original names of scalar actual arguments (VAR_REF)
// that the callee declares as output (funcParams.isArgOut).
//
// If a callee is missing from `funcMap`, a message is printed, every
// entry of `retVal` is set to false, and that call's subtree is not
// visited further.
static void checkFuncCalls(SgExpression* ex, set<string>& writeVars, const map<string, FuncInfo*>& funcMap, SgStatement* st, vector<bool>& retVal)
{
    if (ex == NULL)
        return;

    const bool isUserCall = ex->variant() == FUNC_CALL && !isIntrinsicFunctionName(ex->symbol()->identifier());
    if (isUserCall)
    {
        const auto callee = funcMap.find(ex->symbol()->identifier());
        if (callee == funcMap.end())
        {
            __spf_print(1, "can not find func '%s' in %s:%d\n", ex->symbol()->identifier(), st->fileName(), st->lineNumber());
            //printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
            retVal.assign(retVal.size(), false);
            return;
        }

        int argNum = 0;
        for (SgExpression* arg = ex->lhs(); arg != NULL; arg = arg->rhs())
        {
            SgExpression* actual = arg->lhs();
            if (actual && actual->variant() == VAR_REF && callee->second->funcParams.isArgOut(argNum))
                writeVars.insert(OriginalSymbol(actual->symbol())->identifier());
            ++argNum;
        }
    }

    checkFuncCalls(ex->lhs(), writeVars, funcMap, st, retVal);
    checkFuncCalls(ex->rhs(), writeVars, funcMap, st, retVal);
}
//TODO: check VAR_REF declaration in common block / module
// Returns one flag per element of the subscript list `subs`. A flag is true
// when that subscript contains none of the expression kinds in
// `noSimpleVars`.
//
// When VAR_REF is in `noSimpleVars`, a scalar is still accepted unless it
// may be modified. Modified means it is in `usedVars`, or it is written in
// [stS, stE) by:
//   - an assignment,
//   - a DO loop variable, or
//   - an output argument of a called procedure/function.
// If stS == stE, that single statement is scanned. A callee missing from
// `funcMap` marks every subscript as not simple.
vector<bool> isSimpleRef(SgStatement* stS, SgStatement* stE, SgExpression *subs, const set<int> noSimpleVars, const map<string, FuncInfo*>& funcMap,
                         const set<string>& usedVars)
{
    size_t subsCount = 0;
    for (SgExpression* it = subs; it != NULL; it = it->rhs())
        ++subsCount;
    vector<bool> simple(subsCount, true);

    set<string> writeVars(usedVars);
    const bool trackScalars = noSimpleVars.count(VAR_REF) != 0;
    if (trackScalars && stS && stE)
    {
        SgStatement* const stop = (stE == stS) ? stE->lexNext() : stE;
        for (SgStatement* st = stS; st != stop; st = st->lexNext())
        {
            const int kind = st->variant();
            if (kind == ASSIGN_STAT && st->expr(0)->variant() == VAR_REF)
                writeVars.insert(OriginalSymbol(st->expr(0)->symbol())->identifier());

            if (kind == PROC_STAT && !isIntrinsicFunctionName(st->symbol()->identifier()))
            {
                const auto callee = funcMap.find(st->symbol()->identifier());
                if (callee == funcMap.end())
                {
                    __spf_print(1, "can not find proc '%s' in %s:%d\n", st->symbol()->identifier(), st->fileName(), st->lineNumber());
                    //printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
                    simple.assign(simple.size(), false);
                    break;
                }

                int argNum = 0;
                for (SgExpression* arg = st->expr(0); arg != NULL; arg = arg->rhs(), ++argNum)
                {
                    SgExpression* actual = arg->lhs();
                    if (actual && actual->variant() == VAR_REF && callee->second->funcParams.isArgOut(argNum))
                        writeVars.insert(OriginalSymbol(actual->symbol())->identifier());
                }
            }

            if (kind == FOR_NODE)
                writeVars.insert(OriginalSymbol(isSgForStmt(st)->doName())->identifier());

            for (int slot = 0; slot < 3; ++slot)
                checkFuncCalls(st->expr(slot), writeVars, funcMap, st, simple);
        }
    }

    size_t pos = 0;
    for (SgExpression* it = subs; it != NULL; it = it->rhs(), ++pos)
    {
        if (simple[pos])
            simple[pos] = checkExpr(it->lhs(), noSimpleVars, writeVars);
    }
    return simple;
}
// Returns true if a statement with variant `var` occurs in
// [st, st->lastNodeOfStmt()). The closing node itself is not inspected.
static bool inline hasDirs(SgStatement *st, const int var)
{
    SgStatement* const stop = st->lastNodeOfStmt();
    SgStatement* cur = st;
    while (cur != stop)
    {
        if (cur->variant() == var)
            return true;
        cur = cur->lexNext();
    }
    return false;
}
// True when the array declared for the original symbol of `in` is known
// and is not marked as non-distributed.
static bool isDistributed(SgSymbol *in)
{
    SgSymbol *orig = OriginalSymbol(in);
    DIST::Array *declared = getArrayFromDeclarated(declaratedInStmt(orig), orig->identifier());
    return declared != NULL && !declared->IsNotDistribute();
}
//TODO: need to add IPA (functions)
// Records every reference to a distributed array found in `ex` into
// `readArrays`, keyed by array name and then by the unparsed reference text.
// The control parent `cp` goes into .first and the statement `st` into .second.
static void fillRead(SgExpression *ex, SgStatement *cp, SgStatement *st,
                     map<string, map<string, pair<set<SgStatement*>, set<SgStatement*>>>> &readArrays)
{
    if (ex == NULL)
        return;

    if (ex->variant() == ARRAY_REF && isDistributed(ex->symbol()))
    {
        const string accessText(ex->unparse());
        auto& places = readArrays[ex->symbol()->identifier()][accessText];
        places.first.insert(cp);
        places.second.insert(st);
    }

    fillRead(ex->lhs(), cp, st, readArrays);
    fillRead(ex->rhs(), cp, st, readArrays);
}
// Returns true if the expression `ex` reads at least one distributed array.
bool isNeedToConvertIfCondition(SgExpression *ex)
{
    map<string, map<string, pair<set<SgStatement*>, set<SgStatement*>>>> distributedReads;
    fillRead(ex, NULL, NULL, distributedReads);
    return !distributedReads.empty();
}
// Returns true if [stIn, stIn->lastNodeOfStmt()) contains a jump statement:
// GOTO, assigned GOTO, computed GOTO, EXIT or CYCLE.
static bool inline hasGoTo(SgStatement* stIn)
{
    SgStatement* const stop = stIn->lastNodeOfStmt();
    SgStatement* cur = stIn;
    while (cur != stop)
    {
        const bool isJump = isSgGotoStmt(cur) || isSgAssignedGotoStmt(cur) || isSgComputedGotoStmt(cur) ||
                            isSgExitStmt(cur) || isSgCycleStmt(cur);
        if (isJump)
            return true;
        cur = cur->lexNext();
    }
    return false;
}
// Recognizes an IF statement whose condition has the form `.NOT. v`, where
// v is a plain variable whose name contains "spf_If_C".
static bool inline hasSpecialIfCond(SgStatement* stIn)
{
    if (stIn->variant() != IF_NODE)
        return false;

    SgExpression* cond = stIn->expr(0);
    if (cond == NULL || cond->variant() != NOT_OP)
        return false;

    SgExpression* operand = cond->lhs();
    const bool isPlainVar = operand->variant() == VAR_REF && operand->lhs() == NULL && operand->rhs() == NULL;
    if (!isPlainVar)
        return false;
    return string(operand->symbol()->identifier()).find("spf_If_C") != string::npos;
}
// Returns true when, in the statement range [stIn, last), some distributed
// array is assigned (appears as the left-hand side of an assignment) and the
// same array is also read somewhere in that range.
// Only array names are compared at the moment; the per-access maps
// (control parents / statements) are collected for the finer check that is
// still commented out below (see TODO).
static bool inline hasAssignsToArray(SgStatement *stIn)
{
    // array -> unparse access -> pair [ control par, original stat]
    map<string, map<string, pair<set<SgStatement*>, set<SgStatement*>>>> arrayAccessWrite;
    map<string, map<string, pair<set<SgStatement*>, set<SgStatement*>>>> arrayAccessRead;
    SgStatement *last = stIn->lastNodeOfStmt();
    // For an IF, keep following lastNodeOfStmt() until a CONTROL_END is reached;
    // presumably this extends the range over ELSE IF / ELSE parts -- TODO confirm.
    if (stIn->variant() == IF_NODE)
    {
        while (last->variant() != CONTROL_END)
            last = last->lastNodeOfStmt();
    }
    // Pass 1: writes = left-hand sides of assignments that are distributed arrays.
    for (auto st = stIn; st != last; st = st->lexNext())
    {
        if (st->variant() == ASSIGN_STAT)
        {
            SgExpression *ex = st->expr(0);
            if (ex->variant() == ARRAY_REF)
            {
                SgSymbol *s = ex->symbol();
                if (isDistributed(s))
                {
                    arrayAccessWrite[s->identifier()][string(ex->unparse())].first.insert(st->controlParent());
                    arrayAccessWrite[s->identifier()][string(ex->unparse())].second.insert(st);
                }
            }
        }
    }
    // Pass 2: reads of distributed arrays. For an assignment, the target
    // reference itself is not a read, but its subscripts (left->lhs()/rhs()) are.
    for (auto st = stIn; st != last; st = st->lexNext())
    {
        SgStatement *cp = st->controlParent();
        if (st->variant() != ASSIGN_STAT)
        {
            for (int z = 0; z < 3; ++z)
                fillRead(st->expr(z), cp, st, arrayAccessRead);
        }
        else
        {
            for (int z = 1; z < 3; ++z)
                fillRead(st->expr(z), cp, st, arrayAccessRead);
            SgExpression *left = st->expr(0);
            fillRead(left->lhs(), cp, st, arrayAccessRead);
            fillRead(left->rhs(), cp, st, arrayAccessRead);
        }
    }
    // Any array that is both written and read yields true.
    for (auto &readPair : arrayAccessRead)
    {
        string arrayName = readPair.first;
        auto it = arrayAccessWrite.find(arrayName);
        //TODO:
        if (it != arrayAccessWrite.end())
        {
            return true;
            /*for (auto &read : readPair.second)
            {
                auto mapW = it->second.find(read.first);
                if (mapW == it->second.end())
                    return true;
                else
                {
                    for (auto &cpW : mapW->second.first)
                    {
                        for (auto &cpR : read.second.first)
                            if (cpW != cpR && cpW->variant() == FOR_NODE)
                                return true;
                    }
                    for (auto &cpW : mapW->second.second)
                        if (read.second.second.find(cpW) == read.second.second.end())
                            return true;
                }
            }*/
        }
    }
    return false;
}
static bool converToDDOT(const vector<bool>& checkResult, SgExpression *spec)
{
bool done = false;
int currIdx = 0;
while (spec)
{
if (currIdx >= checkResult.size() && checkResult.size() > 0)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
if (checkResult.size() == 0 || !checkResult[currIdx])
{
spec->setLhs(new SgExpression(DDOT));
done = true;
}
++currIdx;
spec = spec->rhs();
}
return done;
}
// Collects the identifiers of all VAR_REF nodes in the tree `ex` into `symbs`.
static void getAllSymbols(SgExpression *ex, set<string> &symbs)
{
    if (ex == NULL)
        return;

    if (ex->variant() == VAR_REF)
        symbs.insert(ex->symbol()->identifier());

    getAllSymbols(ex->lhs(), symbs);
    getAllSymbols(ex->rhs(), symbs);
}
// Returns true if some subscript in the list `spec` is a real range. A range
// is real when it is a DDOT with a missing bound, or with lower and upper
// bounds that unparse differently (so `i:i` does not count).
static bool ifRange(SgExpression* spec)
{
    for (; spec != NULL; spec = spec->rhs())
    {
        SgExpression* sub = spec->lhs();
        if (sub->variant() != DDOT)
            continue;

        SgExpression* lower = sub->lhs();
        SgExpression* upper = sub->rhs();
        if (lower == NULL || upper == NULL)
            return true;
        if (string(lower->unparse()) != string(upper->unparse()))
            return true;
    }
    return false;
}
// True for ALLOCATE and DEALLOCATE statements.
static bool isAlloc(SgStatement* st)
{
    switch (st->variant())
    {
    case ALLOCATE_STMT:
    case DEALLOCATE_STMT:
        return true;
    default:
        return false;
    }
}
// Counts the statements between an IF header and its closing node. Returns 0
// for a NULL pointer or when the IF has an ELSE part (falseBody() != NULL).
static int numOfOperators(SgIfStmt* st)
{
    if (st == NULL || st->falseBody() != NULL)
        return 0;

    int count = 0;
    for (SgStatement* cur = st->lexNext(); cur != st->lastNodeOfStmt(); cur = cur->lexNext())
        ++count;
    return count;
}
// Creates (or extends) a REMOTE_ACCESS directive for the array references of
// st->expr(NUM) that touch BLOCK-distributed dimensions (see checkArrayRef).
//
// The directive is placed in front of the uppermost enclosing control
// statement that contains none of the following:
//   - a PARALLEL or REMOTE_ACCESS directive,
//   - a write/read of the same distributed array (hasAssignsToArray),
//   - a special spf_If_C condition,
//   - a jump.
// ELSEIF headers are never used as the insertion point. Subscripts that are
// not invariant at that level are turned into ':'. If the statement right
// before the insertion point is already a REMOTE_ACCESS directive, the new
// references are merged into it instead.
//
// Returns false when nothing applies (statement after a PARALLEL directive,
// ALLOCATE/DEALLOCATE, no candidate references). Returns true otherwise.
// `sortedLoopGraph` and `currMessages` are not used here.
template<int NUM>
bool createRemoteDir(SgStatement *st, const map<string, FuncInfo*>& funcMap, const map<int, LoopGraph*> &sortedLoopGraph, const DIST::Arrays<int> &allArrays,
                     const DataDirective &data, const vector<int> &currVar, const uint64_t regionId, vector<Messages> &currMessages,
                     const map<DIST::Array*, set<DIST::Array*>> &arrayLinksByFuncCalls)
{
    //for parallel loops after vector assign conversion
    SgStatement* before = st->lexPrev();
    if (before && before->variant() == DVM_PARALLEL_ON_DIR)
        return false;

    if (isAlloc(st))
        return false;
    if (st->variant() == LOGIF_NODE)
    {
        SgStatement* body = ((SgLogIfStmt*)st)->body();
        if (body && isAlloc(body))
            return false;
    }

    vector<pair<SgExpression*, SgExpression*>> remotes; // <origRef, copyOfRefFromAttr>
    string leftPartOfAssign = "";
    if (st->variant() == ASSIGN_STAT)
        leftPartOfAssign = getArrayRefCopyFromAttribute(st->expr(0))->unparse();

    if (!findAllArraysForRemote(st, st->expr(NUM), remotes, allArrays, data, currVar, regionId, arrayLinksByFuncCalls))
        return false;

    SgStatement *remoteDir = new SgStatement(DVM_REMOTE_ACCESS_DIR);
    SgExpression *exprList = new SgExpression(EXPR_LIST);
    remoteDir->setExpression(0, *exprList);

    //exclude left part of assign: A(i,j,k) = A(i,j,k) + 5
    if (leftPartOfAssign != "")
    {
        auto sameAsLeft = [&leftPartOfAssign](const pair<SgExpression*, SgExpression*>& rem)
        {
            return leftPartOfAssign == string(rem.second->unparse());
        };
        remotes.erase(std::remove_if(remotes.begin(), remotes.end(), sameAsLeft), remotes.end());
    }

    // NOTE(review): when every candidate was the assignment target, nothing is
    // inserted but true is still returned (behavior kept as-is) -- confirm intent.
    if (remotes.empty())
        return true;

    vector<SgExpression*> allSubs;
    for (auto &rem : remotes)
        allSubs.push_back(rem.first->lhs());

    //TODO: use CFG and RD analysis (usedSymbols is collected for it but not used yet)
    set<string> usedSymbols;
    for (auto &access : remotes)
        getAllSymbols(access.first, usedSymbols);

    // climb to the uppermost control parent that can still host the directive
    SgStatement *toInsert = st;
    vector<SgStatement*> allToInsert = { toInsert };
    bool passedForNode = false;
    for (;;)
    {
        SgStatement *parent = toInsert->controlParent();
        const int var = parent->variant();
        if (var == FUNC_HEDR || var == PROC_HEDR || var == PROG_HEDR ||
            hasDirs(parent, DVM_PARALLEL_ON_DIR) ||
            hasDirs(parent, DVM_REMOTE_ACCESS_DIR) ||
            hasAssignsToArray(parent) ||
            hasSpecialIfCond(parent) ||
            hasGoTo(parent))
        {
            break;
        }

        toInsert = parent;
        allToInsert.push_back(toInsert);
        if (toInsert->variant() == FOR_NODE)
            passedForNode = true;
    }

    // never insert in front of an ELSEIF: step back down into its body
    for (int idx = (int)allToInsert.size() - 1; idx > 0 && allToInsert[idx]->variant() == ELSEIF_NODE; --idx)
        toInsert = allToInsert[idx - 1];

    if (toInsert->variant() == FOR_NODE || toInsert->variant() == WHILE_NODE)
    {
        for (auto &elem : allSubs)
        {
            const vector<bool> checkRes = isSimpleRef(toInsert, toInsert->lastNodeOfStmt(), elem, { ARRAY_OP, ARRAY_REF, VAR_REF }, funcMap, set<string>());
            converToDDOT(checkRes, elem);
        }
    }
    else
    {
        set<int> varsToCheck = { ARRAY_OP, ARRAY_REF };
        if (passedForNode)
            varsToCheck.insert(VAR_REF);

        for (auto& elem : allSubs)
        {
            const vector<bool> checkRes = isSimpleRef(toInsert, toInsert->lastNodeOfStmt(), elem, varsToCheck, funcMap, set<string>());
            const bool isSimple = std::all_of(checkRes.begin(), checkRes.end(), [](bool flag) { return flag; });
            if (ifRange(elem) || !isSimple)
                converToDDOT(checkRes, elem);
        }
    }

    //create remote dir with uniq expressions
    set<string> exist;
    bool isFirst = true;
    for (auto& rem : remotes)
    {
        const string currRem = rem.first->unparse();
        if (!exist.insert(currRem).second)
            continue;

        if (!isFirst)
        {
            exprList->setRhs(new SgExpression(EXPR_LIST));
            exprList = exprList->rhs();
        }
        exprList->setLhs(rem.first);
        isFirst = false;
    }

    SgStatement *prev = toInsert->lexPrev();
    if (prev && prev->variant() == DVM_REMOTE_ACCESS_DIR)
    {
        // aggregate into the directive that is already there
        vector<SgExpression*> tmpR;
        for (auto& rem : remotes)
            tmpR.push_back(rem.first);
        prev->setExpression(0, remoteAggregation(prev->expr(0), &tmpR));
    }
    else
    {
        remoteDir->setExpression(0, remoteAggregation(remoteDir->expr(0), NULL));
        toInsert->insertStmtBefore(*remoteDir, *toInsert->controlParent());
    }
    return true;
}
// Explicit instantiations: NUM selects which expression slot of the statement
// (st->expr(0) or st->expr(1)) createRemoteDir scans for remote references.
template bool createRemoteDir<0>(SgStatement*, const map<string, FuncInfo*>&, const map<int, LoopGraph*>&,
                                 const DIST::Arrays<int>&, const DataDirective&, const vector<int>&, const uint64_t,
                                 vector<Messages>&, const map<DIST::Array*, set<DIST::Array*>>&);
template bool createRemoteDir<1>(SgStatement*, const map<string, FuncInfo*>&, const map<int, LoopGraph*>&,
                                 const DIST::Arrays<int>&, const DataDirective&, const vector<int>&, const uint64_t,
                                 vector<Messages>&, const map<DIST::Array*, set<DIST::Array*>>&);
void addRemoteLink(const LoopGraph* loop, const map<string, FuncInfo*>& funcMap, ArrayRefExp *expr, map<string, ArrayRefExp*> &uniqRemotes,
const set<string>& remotesInParallel, set<ArrayRefExp*> &addedRemotes, const vector<string>& mapToLoop,
vector<Messages> &messages, const int line, bool bindToLoopDistribution)
{
SgArrayRefExp* copyExpr = NULL;
bool isConv = false;
if (bindToLoopDistribution)
{
const LoopGraph* withDir = loop;
while (withDir && withDir->loop->GetOriginal()->lexPrev()->variant() != DVM_PARALLEL_ON_DIR)
withDir = withDir->parent;
checkNull(withDir, convertFileName(__FILE__).c_str(), __LINE__);
set<string> loopVars;
for (auto& elem : withDir->directive->parallel)
if (elem != "*")
loopVars.insert(elem);
copyExpr = (SgArrayRefExp*)(expr->copyPtr());
SgExpression* subs = copyExpr->subscripts();
set<string> tmp;
checkNull(loop->loop, convertFileName(__FILE__).c_str(), __LINE__);
//get uppest nested loop
const LoopGraph* UppestNested = loop;
while (UppestNested->parent && UppestNested->parent->children.size() == 1 &&
UppestNested->loop->GetOriginal()->lexPrev()->variant() != DVM_PARALLEL_ON_DIR)
UppestNested = UppestNested->parent;
vector<bool> isSimple = isSimpleRef(UppestNested->loop->GetOriginal(), UppestNested->loop->GetOriginal()->lastNodeOfStmt(), subs, {ARRAY_OP, ARRAY_REF, VAR_REF}, funcMap, loopVars);
SgExpression* ex = subs;
for (int z = 0; z < isSimple.size(); ++z, ex = ex->rhs())
{
if (!isSimple[z] && mapToLoop.size() && mapToLoop[z] != "")
{ // check for A*x + B
SgExpression* list = ex->lhs();
if (!list->lhs() && !list->rhs()) // I
{
if (list->variant() == VAR_REF && list->symbol() && list->symbol()->identifier() == mapToLoop[z])
isSimple[z] = true;
}
else if (list->variant() == ADD_OP) // I + B or B + I
{
if (list->lhs() && list->rhs())
{
SgExpression* left = list->lhs();
SgExpression* right = list->rhs();
if (right->variant() == VAR_REF && left->variant() != VAR_REF)
std::swap(left, right);
if (left->variant() == VAR_REF && right->variant() != VAR_REF)
{
if (left->variant() == VAR_REF && left->symbol() && left->symbol()->identifier() == mapToLoop[z])
{
int const rVar = right->variant();
if (rVar == CONST_REF || isSgValueExp(right))
isSimple[z] = true;
}
}
}
}
}
}
isConv = converToDDOT(isSimple, subs);
}
else
{
set<string> loopVars;
copyExpr = (SgArrayRefExp*)(expr->copyPtr());
SgExpression* subs = copyExpr->subscripts();
vector<bool> isSimple = isSimpleRef(loop->loop->GetOriginal(), loop->loop->GetOriginal()->lastNodeOfStmt(), subs, { ARRAY_OP, ARRAY_REF, VAR_REF }, funcMap, loopVars);
isConv = converToDDOT(isSimple, subs);
}
string remoteExp(copyExpr->unparse());
auto rem = uniqRemotes.find(remoteExp);
if (rem == uniqRemotes.end() && remotesInParallel.find(remoteExp) == remotesInParallel.end())
{
rem = uniqRemotes.insert(rem, make_pair(remoteExp, new ArrayRefExp(copyExpr)));
addedRemotes.insert(new ArrayRefExp(copyExpr));
if (line > 0 && !isConv)
{
string remoteExp(expr->unparse());
__spf_print(1, "WARN: added remote access for array ref '%s' on line %d can significantly reduce performance\n", remoteExp.c_str(), line);
wstring bufE, bufR;
__spf_printToLongBuf(bufE, L"Added remote access for array ref '%s' can significantly reduce performance", to_wstring(remoteExp).c_str());
__spf_printToLongBuf(bufR, R129, to_wstring(remoteExp).c_str());
messages.push_back(Messages(WARR, line, bufR, bufE, 3009));
}
}
}
// Builds the reference `forArray(:, :, ..., :)` (one ':' per dimension), named
// as the array is visible at the loop's real statement in the current file.
ArrayRefExp* createRemoteLink(const LoopGraph* currLoop, const DIST::Array* forArray)
{
    SgFile* file = current_file;
    // NOTE(review): result is unused; the call is kept in case it has side effects
    const set<string> allFiles = getAllFilesInProject();
    SgStatement* realStat = (SgStatement*)currLoop->getRealStat(file->filename());

    const int dimCount = forArray->GetDimSize();
    SgExpression* subscripts = new SgExpression(EXPR_LIST);
    SgExpression* cursor = subscripts;
    for (int dim = 0; dim < dimCount; ++dim)
    {
        cursor->setLhs(new SgExpression(DDOT));
        const bool hasMore = dim + 1 < dimCount;
        if (hasMore)
        {
            cursor->setRhs(new SgExpression(EXPR_LIST));
            cursor = cursor->rhs();
        }
    }

    SgArrayRefExp* remoteRef = new SgArrayRefExp(*((SgSymbol*)forArray->GetNameInLocationS(realStat)), *subscripts);
    return new ArrayRefExp(remoteRef);
}
void addRemotesToDir(const pair<SgForStmt*, LoopGraph*> *under_dvm_dir, const map<string, ArrayRefExp*> &uniqRemotes)
{
SgStatement *dir = under_dvm_dir->first->lexPrev();
if (!dir)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
if (dir->variant() != DVM_PARALLEL_ON_DIR)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
SgExpression* list = NULL;
for (auto& rem : uniqRemotes)
{
list = dir->expr(1);
SgExpression *remoteList = NULL;
while (list)
{
if (list->lhs())
{
if (list->lhs()->variant() == REMOTE_ACCESS_OP)
{
remoteList = list->lhs();
break;
}
}
list = list->rhs();
}
list = dir->expr(1);
SgExpression *toAdd = new SgExpression(EXPR_LIST, rem.second, NULL, NULL);
if (!remoteList)
{
remoteList = new SgExpression(REMOTE_ACCESS_OP, toAdd, NULL, NULL);
dir->setExpression(1, *new SgExpression(EXPR_LIST, remoteList, list, NULL));
}
else
{
SgExpression *lastLhs = remoteList->lhs();
remoteList->setLhs(toAdd);
toAdd->setRhs(lastLhs);
}
}
list = dir->expr(1);
while (list)
{
if (list->lhs())
{
if (list->lhs()->variant() == REMOTE_ACCESS_OP)
{
remoteAggregation(list->lhs(), NULL);
break;
}
}
list = list->rhs();
}
}
// IF(.TRUE.) wrappers inserted by groupActualAndRemote(file, false);
// groupActualAndRemote(file, true) dissolves them again.
static set<SgStatement*> addedDummyIf;

// Forward mode (revert == false): looks for a REMOTE_ACCESS directive that is
// followed by two statements:
//   1. a statement whose expr(0) is an array reference whose first subscript
//      is itself an array reference;
//   2. an ACC ACTUAL directive that lists exactly that reference.
// The match also requires some entry of the REMOTE_ACCESS directive to occur
// inside that reference's text. When all this holds, both statements are
// moved into a dummy `IF (.TRUE.)` placed right after the REMOTE_ACCESS
// directive.
//
// Revert mode (revert == true): moves the wrapped statements back in front of
// each dummy IF, keeping their order, and removes the IFs. `file` is only used
// in forward mode.
void groupActualAndRemote(SgFile *file, bool revert)
{
    if (revert)
    {
        for (auto& dummyIf : addedDummyIf)
        {
            SgStatement* end = dummyIf->lastNodeOfStmt();
            SgStatement* st = dummyIf->lexNext();
            while (st != end)
            {
                SgStatement* toMove = st;
                st = st->lexNext();
                toMove = toMove->extractStmt();
                dummyIf->insertStmtBefore(*toMove, *dummyIf->controlParent());
            }
            dummyIf->extractStmt();
        }
        addedDummyIf.clear();
        return;
    }

    for (SgStatement* st = file->firstStatement(); st; st = st->lexNext())
    {
        if (st->variant() != DVM_REMOTE_ACCESS_DIR)
            continue;

        SgStatement* next = st->lexNext();
        if (next == NULL)
            continue;
        SgStatement* nnext = next->lexNext();

        // guard every link: a reference may have no subscript list (lhs() == NULL)
        SgExpression* target = next->expr(0);
        if (target == NULL || target->variant() != ARRAY_REF ||
            target->lhs() == NULL || target->lhs()->lhs() == NULL ||
            target->lhs()->lhs()->variant() != ARRAY_REF)
            continue;

        if (nnext == NULL || nnext->variant() != ACC_ACTUAL_DIR)
            continue;

        const string ref = target->unparse();

        bool ifFullActual = false;
        for (SgExpression* ex = nnext->expr(0); ex; ex = ex->rhs())
        {
            if (ex->lhs() && ex->lhs()->unparse() == ref)
            {
                ifFullActual = true;
                break;
            }
        }

        bool ifRemote = false;
        for (SgExpression* ex = st->expr(0); ex; ex = ex->rhs())
        {
            if (ex->lhs() && ref.find(ex->lhs()->unparse()) != string::npos)
            {
                ifRemote = true;
                break;
            }
        }

        if (!ifRemote || !ifFullActual)
            continue;

        SgStatement* op = next->extractStmt();
        SgStatement* act = nnext->extractStmt();
        SgIfStmt* dummyIf = new SgIfStmt(*new SgValueExp(true), *act);
        st->insertStmtAfter(*dummyIf, *st->controlParent());
        dummyIf->insertStmtAfter(*op, *dummyIf);
        addedDummyIf.insert(dummyIf);
        // continue scanning after the wrapper
        st = dummyIf->lastNodeOfStmt();
    }
}