From 413daa2aea95deca326aa44a7aadde92782d2c41 Mon Sep 17 00:00:00 2001 From: Mikhail Kocharmin Date: Thu, 31 Oct 2024 00:03:32 +0300 Subject: [PATCH] replace_dist_arrays_in_io: insert only necessary copy statements --- .../replace_dist_arrays_in_io.cpp | 97 +++++++++++++++---- 1 file changed, 80 insertions(+), 17 deletions(-) diff --git a/sapfor/experts/Sapfor_2017/_src/Transformations/replace_dist_arrays_in_io.cpp b/sapfor/experts/Sapfor_2017/_src/Transformations/replace_dist_arrays_in_io.cpp index bc4b9b8..da6cfaa 100644 --- a/sapfor/experts/Sapfor_2017/_src/Transformations/replace_dist_arrays_in_io.cpp +++ b/sapfor/experts/Sapfor_2017/_src/Transformations/replace_dist_arrays_in_io.cpp @@ -126,16 +126,70 @@ static void populateDistributedIoArrays(map>& ar __spf_print(DEBUG_TRACE, "[replace]\n"); } -static void replaceArrayRec(SgSymbol* arr, SgSymbol* replace_by, SgExpression* exp) +static void replaceArrayRec(SgSymbol* arr, SgSymbol* replace_by, SgExpression* exp, bool& has_read, bool& has_write, bool from_read, bool from_write) { if (!exp) return; if (exp->symbol() && strcmp(exp->symbol()->identifier(), arr->identifier()) == 0) + { + has_read |= from_read; + has_write |= from_write; exp->setSymbol(replace_by); + } - replaceArrayRec(arr, replace_by, exp->lhs()); - replaceArrayRec(arr, replace_by, exp->rhs()); + switch (exp->variant()) + { + case FUNC_CALL: + { + replaceArrayRec(arr, replace_by, exp->rhs(), has_read, has_write, true, false); + replaceArrayRec(arr, replace_by, exp->lhs(), has_read, has_write, true, true); + break; + } + case EXPR_LIST: + { + replaceArrayRec(arr, replace_by, exp->lhs(), has_read, has_write, from_read, from_write); + replaceArrayRec(arr, replace_by, exp->rhs(), has_read, has_write, from_read, from_write); + break; + } + default: + { + replaceArrayRec(arr, replace_by, exp->lhs(), has_read, has_write, true, false); + replaceArrayRec(arr, replace_by, exp->rhs(), has_read, has_write, true, false); + break; + } + } +} + +static void replaceArrayRec(SgSymbol* arr, SgSymbol* replace_by, SgStatement* st, bool& has_read, bool& has_write) +{ + if (!st) + return; + + switch (st->variant()) + { + case ASSIGN_STAT: + case READ_STAT: + { + replaceArrayRec(arr, replace_by, st->expr(0), has_read, has_write, false, true); + replaceArrayRec(arr, replace_by, st->expr(1), has_read, has_write, true, false); + break; + } + case PROC_STAT: + case FUNC_STAT: + { + replaceArrayRec(arr, replace_by, st->expr(0), has_read, has_write, true, false); + replaceArrayRec(arr, replace_by, st->expr(1), has_read, has_write, true, true); + break; + } + default: + { + for (int i = 0; i < 3; i++) + replaceArrayRec(arr, replace_by, st->expr(i), has_read, has_write, true, false); + + break; + } + } } static void copyArrayBetweenStatements(SgSymbol* replace_symb, SgSymbol* replace_by, SgStatement* start, SgStatement* last) @@ -144,24 +198,33 @@ static void copyArrayBetweenStatements(SgSymbol* replace_symb, SgSymbol* replace start = start->lexNext(); auto* stop = last->lexNext(); + + bool has_read = false, has_write = false; + for (auto* st = start; st != stop; st = st->lexNext()) - for (int i = 0; i < 3; i++) - replaceArrayRec(replace_symb, replace_by, st->expr(i)); + replaceArrayRec(replace_symb, replace_by, st, has_read, has_write); - // A_copy = A - SgAssignStmt* assign = new SgAssignStmt(*new SgArrayRefExp(*replace_by), *new SgArrayRefExp(*replace_symb)); - assign->setlineNumber(getNextNegativeLineNumber()); // before region - auto* parent = start->controlParent(); - if (parent && parent->lastNodeOfStmt() == start) - parent = parent->controlParent(); - start->insertStmtAfter(*assign, *parent); + if (has_read) + { + // A_copy = A + SgAssignStmt* assign = new SgAssignStmt(*new SgArrayRefExp(*replace_by), *new SgArrayRefExp(*replace_symb)); + assign->setlineNumber(getNextNegativeLineNumber()); // before region + auto* parent = start->controlParent(); + if (parent && parent->lastNodeOfStmt() == start) + parent = parent->controlParent(); - // A = A_reg - assign = new SgAssignStmt(*new SgArrayRefExp(*replace_symb), *new SgArrayRefExp(*replace_by)); - //TODO: bug with insertion - //assign->setlineNumber(getNextNegativeLineNumber()); // after region - last->insertStmtBefore(*assign, *(last->controlParent())); + start->insertStmtAfter(*assign, *parent); + } + + if (has_write) + { + // A = A_reg + SgAssignStmt* assign = new SgAssignStmt(*new SgArrayRefExp(*replace_symb), *new SgArrayRefExp(*replace_by)); + //TODO: bug with insertion + //assign->setlineNumber(getNextNegativeLineNumber()); // after region + last->insertStmtBefore(*assign, *(last->controlParent())); + } } static void replaceArrayInFragment(DIST::Array* arr, const set usages, SgSymbol* replace_by, SgStatement* start, SgStatement* last, const string& filename)