improved MERGE_ARRAYS

This commit is contained in:
ALEXks
2026-05-03 21:00:07 +03:00
parent 51f97e2be9
commit 5a3c936b5c
8 changed files with 77 additions and 75 deletions

View File

@@ -163,10 +163,9 @@ set(PARALLEL_REG src/ParallelizationRegions/ParRegions.cpp
src/ParallelizationRegions/resolve_par_reg_conflicts.cpp
src/ParallelizationRegions/resolve_par_reg_conflicts.h
src/ParallelizationRegions/uniq_name_creator.cpp
src/ParallelizationRegions/uniq_name_creator.h)
set(MERGE_REGIONS src/ParallelizationRegions/merge_regions.cpp
src/ParallelizationRegions/merge_regions.h)
src/ParallelizationRegions/uniq_name_creator.h
src/ParallelizationRegions/merge_regions.h
src/ParallelizationRegions/merge_regions.cpp)
set(TR_DEAD_CODE src/Transformations/DeadCodeRemoving/dead_code.cpp
src/Transformations/DeadCodeRemoving/dead_code.h)
@@ -430,7 +429,6 @@ set(SOURCE_EXE
${LOOP_ANALYZER}
${TRANSFORMS}
${PARALLEL_REG}
${MERGE_REGIONS}
${PRIV}
${FDVM}
${OMEGA}
@@ -484,7 +482,6 @@ source_group (GraphCall FILES ${GR_CALL})
source_group (GraphLoop FILES ${GR_LOOP})
source_group (LoopAnalyzer FILES ${LOOP_ANALYZER})
source_group (ParallelizationRegions FILES ${PARALLEL_REG})
source_group (MergeRegions FILES ${MERGE_REGIONS})
source_group (PrivateAnalyzer FILES ${PRIV})
source_group (FDVM_Compiler FILES ${FDVM})
source_group (SageExtension FILES ${OMEGA})

View File

@@ -2,7 +2,6 @@
#include <cstdio>
#include <cstring>
#include <cstring>
#include <cstdlib>
#include <string>
#include <vector>
@@ -584,6 +583,37 @@ void fillCheckpointFromComment(Statement *stIn, map<int, Expression*> &clauses,
template void fillCheckpointFromComment(Statement *stIn, map<int, Expression*> &clauses, set<Symbol*> &vars, set<Symbol*> &expt);
template void fillCheckpointFromComment(Statement *stIn, map<int, Expression*> &clauses, set<string> &vars, set<string> &expt);
template<typename fillType>
void fillMergeArraysFromComment(Statement* stIn, pair<fillType, fillType>& toReplace)
{
if (stIn)
{
SgStatement* st = stIn->GetOriginal();
if (st->variant() == SPF_TRANSFORM_DIR)
{
SgExpression* exprList = st->expr(0);
while (exprList)
{
const int var = exprList->lhs()->variant();
if (var == SPF_MERGE_ARRAYS_OP)
{
if (!exprList->lhs()->lhs() || !exprList->lhs()->rhs())
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
fillType *dummy = NULL;
toReplace = std::make_pair(getData(exprList->lhs()->lhs(), dummy),
getData(exprList->lhs()->rhs(), dummy));
break;
}
exprList = exprList->rhs();
}
}
}
}
template void fillMergeArraysFromComment(Statement* st, pair<string, string>& toReplace);
template void fillMergeArraysFromComment(Statement* st, pair<Symbol*, Symbol*>& toReplace);
void fillInfoFromDirectives(const LoopGraph *loopInfo, ParallelDirective *directive)
{
SgForStmt *currentLoop = (SgForStmt*)loopInfo->loop;

View File

@@ -52,4 +52,7 @@ void fillShrinkFromComment(Statement *stIn, std::vector<std::pair<fillType, std:
template<typename fillType>
void fillCheckpointFromComment(Statement *stIn, std::map<int, Expression*> &clauses, std::set<fillType> &vars, std::set<fillType> &expt);
template<typename fillType>
void fillMergeArraysFromComment(Statement* stIn, std::pair<fillType, fillType>& toReplace);
int getCoverPropertyFromComment(Statement* stIn);

View File

@@ -186,12 +186,13 @@ static bool checkCover(SgStatement* st,
return retVal;
}
static bool checkProcessPrivate(SgStatement* st,
SgStatement* attributeStatement,
const set<Symbol*>& privates,
vector<Messages>& messagesForFile)
static bool checkDeclaration(SgStatement* st,
SgStatement* attributeStatement,
const set<Symbol*>& variables,
vector<Messages>& messagesForFile)
{
// PROCESS_PRIVATE(VAR)
// MERGE_ARRAYS(ARR1, ARR2)
const int var = st->variant();
bool retVal = true;
@@ -203,12 +204,12 @@ static bool checkProcessPrivate(SgStatement* st,
set<string> varDef, varUse;
fillVarsSets(iterator, end, varDef, varUse);
for (auto& privElemS : privates)
for (auto& var : variables)
{
const string privElem = privElemS->GetOriginal()->identifier();
const string varElem = var->GetOriginal()->identifier();
bool defCond = true;
if (varDef.find(privElem) == varDef.end())
if (varDef.find(varElem) == varDef.end())
defCond = false;
if (!defCond)
@@ -1784,7 +1785,7 @@ static inline bool processStat(SgStatement *st, const string &currFile,
fillPrivatesFromComment(new Statement(attributeStatement), privates, SPF_PROCESS_PRIVATE_OP);
if (privates.size())
{
bool result = checkProcessPrivate(st, attributeStatement, privates, messagesForFile);
bool result = checkDeclaration(st, attributeStatement, privates, messagesForFile);
retVal = retVal && result;
}
@@ -1938,11 +1939,20 @@ static inline bool processStat(SgStatement *st, const string &currFile,
if (isSPF_OP(attributeStatement, SPF_MERGE_ARRAYS_OP))
{
attributeStatement->setLocalLineNumber(-1);
/*if (st->variant() != FOR_NODE)
if (!isSgDeclarationStatement(st))
{
BAD_POSITION_FULL(ERROR, "", "", "before", RR1_1, "DO statement", RR1_3, attributeStatement->lineNumber());
BAD_POSITION_FULL(ERROR, "", "", "before", RR1_1, "declataion statement", RR1_9, attributeStatement->lineNumber());
retVal = false;
}*/
}
else
{
pair<Symbol*, Symbol*> toReplacePair;
fillMergeArraysFromComment(new Statement(attributeStatement), toReplacePair);
set<Symbol*> toCheckDecl = { toReplacePair.first };
bool result = checkDeclaration(st, attributeStatement, toCheckDecl, messagesForFile);
retVal = retVal && result;
}
}
}
else if (type == SPF_CHECKPOINT_DIR)

View File

@@ -3,6 +3,7 @@
#include <set>
#include <map>
#include "merge_regions.h"
#include "../DirectiveProcessing/directive_parser.h"
using std::map;
using std::set;
@@ -10,46 +11,6 @@ using std::pair;
using std::string;
using std::vector;
//TODO: need to create new clause!!
static void parseMergeDirective(const char *comment,
vector<pair<string, string>> &parsed_mapping)
{
while (comment)
{
auto *line_end = strchr(comment, '\n');
static const char prefix[] = "!!spf transform(merge_arrays(";
static const auto compare_chars = sizeof(prefix) - 1;
if (strlen(comment) >= compare_chars)
{
std::string comment_cmp(comment, compare_chars);
convertToLower(comment_cmp);
if (comment_cmp == prefix)
{
auto* pair_start = comment + compare_chars;
auto* comma = strchr(pair_start, ',');
if (comma)
{
auto* close_br = strchr(comma + 1, ')');
if (close_br)
{
parsed_mapping.emplace_back(
string(pair_start, comma - pair_start),
string(comma + 1, close_br - comma - 1));
}
}
}
}
comment = line_end;
if (comment)
comment++;
}
}
static string getNonDefaultRegion(DIST::Array *a)
{
string result;
@@ -217,7 +178,7 @@ static SgExpression* findExprWithVariant(SgExpression* exp, int variant)
return NULL;
}
SgType* GetArrayType(DIST::Array *array)
static SgType* GetArrayType(DIST::Array *array)
{
if (!array)
return NULL;
@@ -242,11 +203,11 @@ SgType* GetArrayType(DIST::Array *array)
return NULL;
}
SgSymbol *insertDeclIfNeeded(const string &array_name,
const string &common_block_name,
DIST::Array *example_array,
FuncInfo *dest,
map<FuncInfo *, map<string, SgSymbol *>> &inserted_arrays)
static SgSymbol *insertDeclIfNeeded(const string &array_name,
const string &common_block_name,
DIST::Array *example_array,
FuncInfo *dest,
map<FuncInfo *, map<string, SgSymbol *>> &inserted_arrays)
{
auto *type = GetArrayType(example_array);
@@ -462,17 +423,20 @@ void mergeRegions(vector<ParallelRegion *> &regions, const map<string, vector<Fu
for (; curr_stmt && curr_stmt != stmt_end; curr_stmt = curr_stmt->lexNext())
{
if (curr_stmt->comments())
auto attributes = getAttributes<SgStatement*, SgStatement*>(curr_stmt, set<int>{SPF_TRANSFORM_DIR});
for (auto& attr : attributes)
{
vector<pair<string, string>> parsed_mapping;
parseMergeDirective(curr_stmt->comments(), parsed_mapping);
pair<string, string> parsed_mapping = { "", "" };
const auto empty = parsed_mapping;
fillMergeArraysFromComment(new Statement(attr), parsed_mapping);
for (const auto &p : parsed_mapping)
if (parsed_mapping != empty)
{
auto *found_array = getArrayFromDeclarated(curr_stmt, p.first);
auto* found_array = getArrayFromDeclarated(curr_stmt, parsed_mapping.first);
if (found_array)
{
arrays_to_merge[p.second].insert(found_array);
arrays_to_merge[parsed_mapping.second].insert(found_array);
array_alignment[found_array] = {};
}
}

View File

@@ -521,7 +521,7 @@ static void processLoopBound(SgStatement* st,
getBorderVars(exp, st->fileName(), borderVars);
if (containsArrayRefRecursive(exp), borderVars, st->fileName())
if (containsArrayRefRecursive(exp))
{
copyStatement(st);
@@ -575,7 +575,6 @@ void arrayConstantPropagation(SgProject& project)
processLoopBound(st, st->expr(0), upperBoundUnparsed, true, arrayToVariable, borderVars);
processLoopBound(st, st->expr(0), lowerBoundUnparsed, false, arrayToVariable, borderVars);
}
}
}
@@ -601,9 +600,7 @@ void arrayConstantPropagation(SgProject& project)
{
SgFile::switchToFile(fileName);
for (SgStatement* st : statements)
{
insertCommonAndDeclsForFunction(st, variablesToAdd);
}
}
map<string, map<SgStatement*, vector<pair<string, string>>>> result;

View File

@@ -322,6 +322,7 @@ static const wchar_t *RR1_5 = L"RR1_5:";
static const wchar_t *RR1_6 = L"RR1_6:";
static const wchar_t *RR1_7 = L"RR1_7:";
static const wchar_t *RR1_8 = L"RR1_8:";
static const wchar_t *RR1_9 = L"RR1_9:";
static const wchar_t *R2 = L"R2:";
static const wchar_t *R3 = L"R3:";

View File

@@ -1,3 +1,3 @@
#pragma once
#define VERSION_SPF "2485"
#define VERSION_SPF "2486"