Regression Tree for ranking solutions #1354

Open · wants to merge 1 commit into base: develop
14 changes: 12 additions & 2 deletions tensilelite/Tensile/Contractions.py
@@ -543,7 +543,12 @@ class SizeMapping:
'workGroupMappingXCC',
'workGroupMappingXCCGroup',
'globalSplitUCoalesced',
'globalSplitUWorkGroupMappingRoundRobin'
'globalSplitUWorkGroupMappingRoundRobin',
'ldsBuffer',
'localReadVectorWidth',
'numLoadsCoalescedA',
'storeVectorWidth',
'vectorWidthA'
]

@classmethod
@@ -586,7 +591,12 @@ def FromOriginalState(cls, d):
workGroupMappingXCC = d['WorkGroupMappingXCC'],
workGroupMappingXCCGroup = d['WorkGroupMappingXCCGroup'],
globalSplitUCoalesced = d['GlobalSplitUCoalesced'],
globalSplitUWorkGroupMappingRoundRobin = d['GlobalSplitUWorkGroupMappingRoundRobin']
globalSplitUWorkGroupMappingRoundRobin = d['GlobalSplitUWorkGroupMappingRoundRobin'],
ldsBuffer = d['1LDSBuffer'],
localReadVectorWidth = d['LocalReadVectorWidth'],
numLoadsCoalescedA = d['NumLoadsCoalescedA'],
storeVectorWidth = d['StoreVectorWidth'],
vectorWidthA = d['VectorWidthA']
)

@classmethod
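The Contractions.py change above adds five more kernel tuning parameters ('1LDSBuffer', 'LocalReadVectorWidth', 'NumLoadsCoalescedA', 'StoreVectorWidth', 'VectorWidthA') to the Python-side SizeMapping metadata, presumably so they can serve as per-solution features for the regression tree ranker. A minimal sketch of how such a feature vector could be assembled, using a hypothetical helper that is not part of this PR:

# Hypothetical helper (not part of this PR): build a per-solution feature
# vector from the SizeMapping attributes added in the diff above.
def solution_feature_vector(size_mapping, feature_names):
    # feature_names might be the "solutionFeatures" list stored in the
    # RegressionTree library, e.g.:
    #   ["ldsBuffer", "localReadVectorWidth", "numLoadsCoalescedA",
    #    "storeVectorWidth", "vectorWidthA"]
    return [float(getattr(size_mapping, name)) for name in feature_names]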
46 changes: 46 additions & 0 deletions tensilelite/Tensile/SolutionLibrary.py
@@ -218,6 +218,46 @@ def __init__(self, features, trees, nullValue):
self.trees = trees
self.nullValue = nullValue

class RegressionTreeLibrary:
Tag = "RegressionTree"
StateKeys = [("type", "tag"), "table", "trees", "solutionFeatures", "problemFeatures"]

@classmethod
def FromOriginalState(cls, d, solutions):
origTable = d["table"]
table = []

try:
indexStart = origTable[0]
indexOffset = origTable[1]
for index in range(indexStart, indexStart + indexOffset):
value = IndexSolutionLibrary(solutions[index])
table.append(value)
except KeyError:
pass

trees = d["trees"]
solution_features = d["solutionFeatures"]
problem_features = d["problemFeatures"]
return cls(table, trees, solution_features, problem_features)

@property
def tag(self):
return self.__class__.Tag

def merge(self, other):
raise RuntimeError(
"RegressionTreeLibrary does not support merging."
)

def remapSolutionIndices(self, indexMap):
pass

def __init__(self, table, trees, solution_features, problem_features):
self.table = table
self.trees = trees
self.solutionFeatures = solution_features
self.problemFeatures = problem_features

class ProblemMapLibrary:
Tag = "ProblemMap"
@@ -393,6 +433,12 @@ def selection(d, problemType, solutions, library, placeholderName):

treeLib = DecisionTreeLibrary.FromOriginalState(lib, solutions)
library.rows.append({"predicate": predicate, "library": treeLib})
elif d["LibraryType"] == "RegressionTree":
predicate = Properties.Predicate(tag="TruePred")

regressionLib = RegressionTreeLibrary.FromOriginalState(d["Library"], solutions)
library = PredicateLibrary(tag="Problem")
library.rows.append({"predicate": predicate, "library": regressionLib})
else:
assert 0 and "Unrecognized LibraryType."

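In the new RegressionTreeLibrary, the serialized "table" entry is read as a two-element [indexStart, indexOffset] pair selecting a contiguous run of solution indices, and the new selection() branch wraps the resulting library in a Problem-level PredicateLibrary guarded by a TruePred. A small self-contained illustration of the table interpretation (values made up for the example):

# Illustration of how the "table" entry is interpreted in
# RegressionTreeLibrary.FromOriginalState above; the values are made up.
origTable = [16, 4]                    # [indexStart, indexOffset]
indexStart, indexOffset = origTable
selectedIndices = list(range(indexStart, indexStart + indexOffset))
print(selectedIndices)                 # [16, 17, 18, 19] -> each wrapped as an IndexSolutionLibrary entry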
Additional C++ header (filename not captured)
@@ -908,6 +908,9 @@ namespace TensileLite
case ExperimentalOption::StreamK:
return PerformanceMetric::ExperimentalStreamK;

case ExperimentalOption::RTree:
return PerformanceMetric::ExperimentalRTree;

default:
// warning?
return m_performanceMetric;
@@ -1118,6 +1121,8 @@ namespace TensileLite
return m_arithmeticIntensity;
}

float getLog10Flops() const;

virtual std::vector<ConstantDescriptor> const constants() const
{
std::vector<ConstantDescriptor> c = {{"alpha", m_alphaType}, {"beta", m_betaType}};
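getLog10Flops() is only declared here; its definition is not part of this hunk. Assuming it reports the problem's GEMM flop count on a log10 scale as a compact magnitude feature for the trees, the quantity would look roughly like the sketch below; the 2*M*N*K*batch formula is an assumption, not taken from this diff.

import math

# Assumed meaning of getLog10Flops() (definition not shown in this diff):
# the GEMM flop count 2*M*N*K*batchCount on a log10 scale.
def log10_flops(m, n, k, batch_count=1):
    return math.log10(2.0 * m * n * k * batch_count)

print(round(log10_flops(4096, 4096, 4096), 2))   # ~11.14 for a 4096^3 GEMM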
Additional C++ header (filename not captured)
@@ -1900,6 +1900,40 @@ namespace TensileLite
}
};

struct ExperimentalRTree
: public Predicate_CRTP<ExperimentalRTree, ContractionProblemGemm>
{
enum
{
HasIndex = false,
HasValue = false
};

ExperimentalRTree() = default;

static std::string Type()
{
return "ExperimentalRTree";
}

virtual bool operator()(ContractionProblemGemm const& problem) const override
{
return (problem.performanceMetric() == PerformanceMetric::ExperimentalRTree);
}

virtual bool debugEval(ContractionProblemGemm const& problem,
std::ostream& stream) const override
{
return debugEvalCmp(problem,
stream,
"prob",
problem.performanceMetric(),
"==",
"sol: PerformanceMetric::ExperimentalRTree",
PerformanceMetric::ExperimentalRTree);
}
};

struct EqualityMatching
: public Predicate_CRTP<EqualityMatching, ContractionProblemGemm>
{
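The ExperimentalRTree predicate matches only problems whose performance metric has been set to PerformanceMetric::ExperimentalRTree, so the regression-tree ranking path stays opt-in. In Python terms the gate amounts to the following (illustrative analogue, not Tensile code):

# Illustrative analogue of the ExperimentalRTree predicate: true only when
# the problem explicitly opts in to the regression-tree ranking path.
def experimental_rtree(performance_metric):
    return performance_metric == "ExperimentalRTree"

assert experimental_rtree("ExperimentalRTree")
assert not experimental_rtree("Auto")   # any other metric value falls through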
Additional C++ header (filename not captured)
@@ -149,6 +149,12 @@ namespace TensileLite
int workGroupMappingXCCGroup = 0;
bool globalSplitUCoalesced = false;
bool globalSplitUWorkGroupMappingRoundRobin = false;

int ldsBuffer = 0;
size_t localReadVectorWidth = 1;
size_t numLoadsCoalescedA = 1;
size_t storeVectorWidth = 1;
size_t vectorWidthA = 1;
};

/**
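Taken together, the pieces above suggest the runtime flow: each candidate solution contributes the new SizeMapping fields as its solution features, the problem contributes features such as its log10 flop count, and the regression trees score each (problem, solution) pair so the candidates in the table can be ranked. The evaluation itself lives in C++ code not shown in this diff, so the following is only a hedged sketch of that idea, with a toy tree encoding and made-up feature values:

# Hedged sketch of regression-tree ranking (not the Tensile implementation):
# score each candidate from problem features plus its solution features,
# then pick the highest-scoring solution index.
def predict(tree, features):
    # 'tree' is a toy nested dict: leaves carry a "value", inner nodes
    # split on features[feature] <= threshold.
    node = tree
    while "value" not in node:
        branch = "left" if features[node["feature"]] <= node["threshold"] else "right"
        node = node[branch]
    return node["value"]

def rank_solutions(problem_features, candidates, trees):
    scores = []
    for index, solution_features in candidates:
        features = problem_features + solution_features
        scores.append((sum(predict(t, features) for t in trees), index))
    return max(scores)[1]            # index of the best-predicted solution

# Made-up example: one tree splitting on the solution feature (index 1).
toy_tree = {"feature": 1, "threshold": 6.0,
            "left":  {"value": 0.3},
            "right": {"value": 0.8}}
best = rank_solutions([11.1], [(16, [4.0]), (17, [8.0])], [toy_tree])
print(best)                          # 17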