add intersection and difference operations for ArrayDimension

This commit is contained in:
2024-12-12 02:42:39 +03:00
parent c0c6ed9131
commit c55eabf0ad

View File

@@ -3,6 +3,7 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <queue> #include <queue>
#include <numeric>
#include <iostream> #include <iostream>
#include "private_arrays_search.h" #include "private_arrays_search.h"
@@ -181,7 +182,7 @@ int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAccessingInd
pair currentCoefs = coefsForDims.back(); pair currentCoefs = coefsForDims.back();
ArrayDimension current_dim; ArrayDimension current_dim;
if (var->getType() == SAPFOR::CFG_ARG_TYPE::CONST) { if (var->getType() == SAPFOR::CFG_ARG_TYPE::CONST) {
current_dim = { stoul(var->getValue()), 0, 1 }; current_dim = { stoul(var->getValue()), 1, 1 };
} }
else else
{ {
@@ -231,10 +232,97 @@ int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAccessingInd
} }
vector<uint64_t> FindParticularSolution(const ArrayDimension& dim1, const ArrayDimension& dim2)
{
for (uint64_t i = 0; i < dim1.tripCount; i++)
{
uint64_t leftPart = dim1.start + i * dim1.step;
for (uint64_t j = 0; j < dim2.tripCount; j++)
{
uint64_t rightPart = dim2.start + j * dim2.step;
if (leftPart == rightPart)
{
return {i, j};
}
}
}
return {};
}
/* dim1 /\ dim2 */
ArrayDimension* IntersectDimension(const ArrayDimension& dim1, const ArrayDimension& dim2)
{
vector<uint64_t> partSolution = FindParticularSolution(dim1, dim2);
if (partSolution.empty())
{
return NULL;
}
int64_t x0 = partSolution[0], y0 = partSolution[1];
/* x = x_0 + c * t */
/* y = y_0 + d * t */
int64_t c = dim2.step / gcd(dim1.step, dim2.step);
int64_t d = dim1.step / gcd(dim1.step, dim2.step);
int64_t tXMin, tXMax, tYMin, tYMax;
tXMin = -x0 / c;
tXMax = (dim1.tripCount - 1 - x0) / c;
tYMin = -y0 / d;
tYMax = (dim2.tripCount - 1 - y0) / d;
int64_t tMin = max(tXMin, tYMin);
uint64_t tMax = min(tXMax, tYMax);
if (tMin > tMax)
{
return NULL;
}
uint64_t start3 = dim1.start + x0 * dim1.step;
uint64_t step3 = c * dim1.step;
ArrayDimension* result = new(ArrayDimension){ start3, step3, tMax + 1 };
return result;
}
/* dim1 / dim2 */
vector<ArrayDimension> DimensionDifference(const ArrayDimension& dim1, const ArrayDimension& dim2)
{
ArrayDimension* intersection = IntersectDimension(dim1, dim2);
if (!intersection)
{
return {dim1, dim2};
}
vector<ArrayDimension> result;
/* add the part before intersection */
if (dim1.start < intersection->start)
{
result.push_back({ dim1.start, dim1.step, (intersection->start - dim1.start) / dim1.step });
}
/* add the parts between intersection steps */
uint64_t start = (intersection->start - dim1.start) / dim1.step;
uint64_t interValue = intersection->start;
for (int64_t i = start; dim1.start + i * dim1.step <= intersection->start + intersection->step * (intersection->tripCount - 1); i++)
{
uint64_t centerValue = dim1.start + i * dim1.step;
if (centerValue == interValue)
{
if (i - start > 1)
{
result.push_back({ dim1.start + (start + 1) * dim1.step, dim1.step, i - start - 1 });
start = i;
}
interValue += intersection->step;
}
}
/* add the part after intersection */
if (intersection->start + intersection->step * (intersection->tripCount - 1) < dim1.start + dim1.step * (dim1.tripCount - 1))
{
/* first value after intersection */
uint64_t right_start = intersection->start + intersection->step * (intersection->tripCount - 1) + dim1.step;
uint64_t tripCount = (dim1.start + dim1.step * dim1.tripCount - right_start) / dim1.step;
result.push_back({right_start, dim1.step, tripCount});
}
delete(intersection);
return result;
}
void FindPrivateArrays(map<string, vector<LoopGraph*>> &loopGraph, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR) void FindPrivateArrays(map<string, vector<LoopGraph*>> &loopGraph, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR)
{ {
cout << "FindPrivateArrays\n";
for (const auto& curr_graph_pair: loopGraph) for (const auto& curr_graph_pair: loopGraph)
{ {
for (const auto& curr_loop : curr_graph_pair.second) for (const auto& curr_loop : curr_graph_pair.second)