#include "parse_merge_dirs.h" #include using std::map; using std::pair; using std::string; using std::unordered_map; using std::unordered_set; using std::vector; static void parseMergeDirective(const char *comment, vector> &parsed_mapping) { while (comment) { auto *line_end = strchr(comment, '\n'); static const char prefix[] = "!!SPF TRANSFORM(MERGE_ARRAYS("; static const auto compare_chars = sizeof(prefix) - 1; if (strncasecmp(comment, prefix, compare_chars) == 0) { auto *pair_start = comment + compare_chars; auto *comma = strchr(pair_start, ','); if (comma) { auto *close_br = strchr(comma + 1, ')'); if (close_br) { parsed_mapping.emplace_back( string(pair_start, comma - pair_start), string(comma + 1, close_br - comma - 1)); } } } comment = line_end; if (comment) comment++; } } static string getNonDefaultRegion(DIST::Array *a) { string result; if (!a) return result; for (const auto ®_name : a->GetRegionsName()) { if (reg_name != "default") { if (!result.empty()) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); result = reg_name; } } return result; } static bool hasSameSizes(DIST::Array *a, DIST::Array *b) { for (auto *array : {a, b}) { for (const auto &p : array->GetSizes()) { if (p.first < 0 || p.second < 0) return false; } } return a->GetSizes() == b->GetSizes() && a->GetTypeSize() == b->GetTypeSize(); } static bool checkSimilarTemplates(vector ®ions, const unordered_map &new_region_mapping) { // new region -> old regions unordered_map> new_region_inverse_mapping; for (const auto &p : new_region_mapping) new_region_inverse_mapping[p.second].insert(p.first); for (const auto &new_reg : new_region_inverse_mapping) { DIST::Array *template_array = nullptr; string first_reg_name; for (const auto &old_region_name : new_reg.second) { auto *old_reg = getRegionByName(regions, old_region_name); if (!old_reg) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); const auto &distr_rules = old_reg->GetDataDir().GetDistrRules(); if (distr_rules.size() != 1) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); auto *current_template = distr_rules.front().first; if (template_array) { if (!hasSameSizes(template_array, current_template)) { __spf_print(1, "Templates of %s and %s has different sizes\n", first_reg_name.c_str(), old_region_name.c_str()); return false; } // else everything OK } else { template_array = current_template; first_reg_name = old_region_name; } } } return true; } static bool hasSameAlignment(const std::unordered_set &align_a, const std::unordered_set &align_b) { if (align_a.size() != 1 || align_b.size() != 1) return false; const auto *rule_a = *align_a.begin(); const auto *rule_b = *align_b.begin(); if (rule_a->alignRule != rule_b->alignRule) return false; return true; } static void printExpr(SgExpression *e, string pad) { if (!e) return; __spf_print(1, "%s%d: %s\n", pad.c_str(), e->variant(), e->unparse()); printExpr(e->lhs(), pad + " "); printExpr(e->rhs(), pad + " "); } static pair, SgSymbol *> generateDeclaration(const string &array_name, const string &common_block_name, const vector> &sizes, SgType *type, SgStatement *scope) { auto *array_symbol = new SgSymbol(VARIABLE_NAME, array_name.c_str(), new SgType(T_ARRAY), scope); auto *decl = new SgDeclarationStatement(VAR_DECL); decl->setExpression(1, new SgTypeExp(*type)); SgExpression *subs = new SgExprListExp(); auto *array_ref = new SgArrayRefExp(*array_symbol, *subs); for (int i = 0; i < sizes.size(); i++) { const auto &p = sizes[i]; auto *d = new SgExpression(DDOT, new SgValueExp(p.first), new SgValueExp(p.second)); subs->setLhs(d); if (i + 1 < sizes.size()) { subs->setRhs(new SgExprListExp()); subs = subs->rhs(); } } decl->setExpression(0, array_ref); auto comm = new SgStatement(COMM_STAT); comm->setExpression(0, new SgExpression(COMM_LIST, new SgVarRefExp(array_symbol), NULL, new SgSymbol(COMMON_NAME, common_block_name.c_str()))); return {{decl, comm}, array_symbol}; } static SgExpression* findExprWithVariant(SgExpression* exp, int variant) { if (exp) { if (exp->variant() == variant) return exp; auto *l = findExprWithVariant(exp->lhs(), variant); if (l) return l; auto *r = findExprWithVariant(exp->rhs(), variant); if (r) return r; } return NULL; } SgType* GetArrayType(DIST::Array *array) { if (!array) return NULL; for (const auto& decl_place : array->GetDeclInfo()) { if (SgFile::switchToFile(decl_place.first) != -1) { auto* decl = SgStatement::getStatementByFileAndLine(decl_place.first, decl_place.second); if (decl) { for (int i = 0; i < 3; i++) { auto* found_type = isSgTypeExp(findExprWithVariant(decl->expr(i), TYPE_OP)); if (found_type) return found_type->type(); } } } } return NULL; } SgSymbol *insertDeclIfNeeded(const string &array_name, const string &common_block_name, DIST::Array *example_array, FuncInfo *dest, unordered_map> &inserted_arrays) { auto *type = GetArrayType(example_array); if (!type) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); if (SgFile::switchToFile(dest->fileName) == -1) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); auto &by_func = inserted_arrays[dest]; auto it = by_func.find(array_name); if (it != by_func.end()) return it->second; SgStatement *st = dest->funcPointer; auto *end = st->lastNodeOfStmt(); st = st->lexNext(); while (st != end && !isSgExecutableStatement(st)) { st = st->lexNext(); } auto generated = generateDeclaration(array_name, common_block_name, example_array->GetSizes(), type, dest->funcPointer); for (auto *new_stmt : generated.first) st->insertStmtBefore(*new_stmt, *dest->funcPointer); by_func[array_name] = generated.second; return generated.second; } static pair createNewArray(DIST::Array *example_array, const string &base_name, const map> &allFuncInfo, unordered_map> &inserted_arrays) { auto common_block_name = base_name + "_merge_cb"; auto array_name = base_name; for (const auto &by_file : allFuncInfo) { for (auto *func_info : by_file.second) { if (func_info->isMain) { insertDeclIfNeeded( array_name, common_block_name, example_array, func_info, inserted_arrays); } } } return std::make_pair(array_name, common_block_name); } static void replaceArrayRec(SgExpression *e, const unordered_set &arrays_to_replace, SgSymbol **func_symbol_hint, const pair &replace_by, DIST::Array *example_array, FuncInfo *func, unordered_map> &inserted_arrays) { if (!e) return; if (isArrayRef(e) && arrays_to_replace.find(e->symbol()->identifier()) != arrays_to_replace.end()) { if (!(*func_symbol_hint)) { *func_symbol_hint = insertDeclIfNeeded( replace_by.first, replace_by.second, example_array, func, inserted_arrays); } e->setSymbol(*func_symbol_hint); } replaceArrayRec( e->lhs(), arrays_to_replace, func_symbol_hint, replace_by, example_array, func, inserted_arrays); replaceArrayRec( e->rhs(), arrays_to_replace, func_symbol_hint, replace_by, example_array, func, inserted_arrays); } static void replaceRegion(SgStatement* st, const unordered_map &new_region_mapping) { if (!st) return; if(isSPF_stat(st) && st->variant() == SPF_PARALLEL_REG_DIR) { auto it = new_region_mapping.find(st->symbol()->identifier()); if (it != new_region_mapping.end()) st->setSymbol(*(new SgSymbol(CONST_NAME, it->second.c_str()))); } } void mergeCopyArrays(vector ®ions, const map> &allFuncInfo) { for (const auto *region : regions) { __spf_print(1, "region %s\n", region->GetName().c_str()); const auto &dirs = region->GetDataDir(); __spf_print(1, " distr rules: %d\n", dirs.distrRules.size()); const auto ¤tVariant = region->GetCurrentVariant(); int distr_idx = 0; for (const auto &distr : dirs.distrRules) { const auto &dist_rule = distr.second.back().distRule; string sizes; for (const auto &p : distr.first->GetSizes()) { if (!sizes.empty()) sizes.push_back(','); sizes += std::to_string(p.first) + ":" + std::to_string(p.second); } __spf_print(1, " DIST %s(%s)", distr.first->GetName().c_str(), sizes.c_str()); for (const auto &dim : dist_rule) __spf_print(1, " %c", dim == dist::BLOCK ? 'B' : '*'); __spf_print(1, "\n"); distr_idx++; } __spf_print(1, " align rules: %d\n", dirs.alignRules.size()); for (const auto &align : dirs.alignRules) { string sub_a, sub_b; int i = 0; for (const auto coefs : align.alignRule) { if (!sub_a.empty()) sub_a.push_back(','); sub_a += std::to_string(coefs.first) + "*i" + std::to_string(i) + "+" + std::to_string(coefs.second); i++; } for (const auto coefs : align.alignRuleWith) { if (!sub_b.empty()) sub_b.push_back(','); sub_b += std::to_string(coefs.second.first) + "*i" + std::to_string(coefs.first) + "+" + std::to_string(coefs.second.second); } __spf_print(1, " ALIGN %s(%s) WITH %s(%s)\n", align.alignArray->GetName().c_str(), sub_a.c_str(), align.alignWith->GetName().c_str(), sub_b.c_str()); } } // parse directives // new array name -> current arrays unordered_map> arrays_to_merge; unordered_map> array_alignment; for (const auto &by_file : allFuncInfo) { const auto current_file_name = by_file.first; if (SgFile::switchToFile(current_file_name) == -1) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); for (auto *func_info : by_file.second) { SgStatement *curr_stmt = func_info->funcPointer; if (!curr_stmt) continue; auto *stmt_end = curr_stmt->lastDeclaration(); if (!stmt_end) continue; stmt_end = stmt_end->lexNext(); for (; curr_stmt && curr_stmt != stmt_end; curr_stmt = curr_stmt->lexNext()) { if (curr_stmt->comments()) { vector> parsed_mapping; parseMergeDirective(curr_stmt->comments(), parsed_mapping); for (const auto &p : parsed_mapping) { auto *found_array = getArrayFromDeclarated(curr_stmt, p.first); if (found_array) { arrays_to_merge[p.second].insert(found_array); array_alignment[found_array] = {}; } } } } } } // find alignment rules for array for (const auto *region : regions) { const auto &dirs = region->GetDataDir(); for (const auto &align : dirs.alignRules) { auto it = array_alignment.find(align.alignArray); if (it != array_alignment.end()) it->second.insert(&align); } } // old region -> new region unordered_map new_region_mapping; // new array -> new region unordered_map arrays_new_region_mapping; vector created_region_names; for (const auto &by_new_array : arrays_to_merge) { string new_region_name; for (auto *current_array : by_new_array.second) { auto current_array_region = getNonDefaultRegion(current_array); auto it = new_region_mapping.find(current_array_region); if (it != new_region_mapping.end()) { if (new_region_name.empty()) new_region_name = it->second; else if (new_region_name != it->second) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); } } if (new_region_name.empty()) { new_region_name = "merged_reg_" + std::to_string(created_region_names.size()); created_region_names.push_back(new_region_name); } for (auto *current_array : by_new_array.second) { auto current_array_region = getNonDefaultRegion(current_array); new_region_mapping[current_array_region] = new_region_name; } arrays_new_region_mapping[by_new_array.first] = new_region_name; } if (!checkSimilarTemplates(regions, new_region_mapping)) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); unordered_map> inserted_arrays; for (const auto &by_dest_array : arrays_to_merge) { const auto ©_arrays = by_dest_array.second; if (copy_arrays.empty()) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); auto *first_element = *copy_arrays.begin(); auto first_elem_rules_it = array_alignment.find(first_element); if (first_elem_rules_it == array_alignment.end()) continue; const auto &first_elem_rules = first_elem_rules_it->second; for (auto *array_to_merge : copy_arrays) { auto array_rules_it = array_alignment.find(array_to_merge); if (array_rules_it == array_alignment.end()) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); const auto &array_rules = array_rules_it->second; if (!hasSameSizes(array_to_merge, first_element) || !hasSameAlignment(first_elem_rules, array_rules)) { __spf_print(1, "Arrays %s and %s has different sizes or align rules\n", array_to_merge->GetName().c_str(), first_element->GetName().c_str()); printInternalError(convertFileName(__FILE__).c_str(), __LINE__); } } __spf_print(1, "merge into %s (%s):\n", by_dest_array.first.c_str(), arrays_new_region_mapping[by_dest_array.first].c_str()); for (auto *array_to_merge : copy_arrays) __spf_print(1, "%s\n", array_to_merge->GetName().c_str()); auto created_array_info = createNewArray(first_element, by_dest_array.first, allFuncInfo, inserted_arrays); unordered_set arrays_to_replace; for (auto *array_to_merge : copy_arrays) arrays_to_replace.insert(array_to_merge->GetShortName()); for (const auto &by_file : allFuncInfo) { if (SgFile::switchToFile(by_file.first) == -1) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); for (auto *func_info : by_file.second) { SgSymbol *func_symbol_hint = nullptr; SgStatement *st = func_info->funcPointer; auto *func_end = st->lastNodeOfStmt(); st = st->lexNext(); while (st && !isSgExecutableStatement(st) && st != func_end) st = st->lexNext(); while (st && st != func_end) { for (int i = 0; i < 3; i++) { replaceArrayRec( st->expr(i), arrays_to_replace, &func_symbol_hint, created_array_info, first_element, func_info, inserted_arrays); } replaceRegion(st, new_region_mapping); st = st->lexNext(); } } } } }