From e3535bbf25518835fa17292db1cc27a1325fd5ee Mon Sep 17 00:00:00 2001 From: Yen Ong Date: Fri, 15 Nov 2024 12:59:43 -0600 Subject: [PATCH] Regression Tree for ranking solution selection --- tensilelite/Tensile/Contractions.py | 14 +- tensilelite/Tensile/SolutionLibrary.py | 46 ++++ .../include/Tensile/ContractionProblem.hpp | 5 + .../Tensile/ContractionProblemPredicates.hpp | 34 +++ .../include/Tensile/ContractionSolution.hpp | 6 + .../Source/lib/include/Tensile/MLFeatures.hpp | 248 ++++++++++++++++++ .../Tensile/PerformanceMetricTypes.hpp | 8 +- .../lib/include/Tensile/RegressionTree.hpp | 214 +++++++++++++++ .../include/Tensile/RegressionTreeLibrary.hpp | 226 ++++++++++++++++ .../Tensile/Serialization/Containers.hpp | 15 ++ .../Serialization/ContractionPredicates.hpp | 7 + .../Serialization/ContractionSolution.hpp | 6 + .../Tensile/Serialization/MLFeatures.hpp | 129 ++++++++- .../Serialization/RegressionTreeLibrary.hpp | 172 ++++++++++++ .../Tensile/Serialization/SolutionLibrary.hpp | 4 +- .../Source/lib/source/ContractionProblem.cpp | 21 ++ .../lib/source/PerformanceMetricTypes.cpp | 5 + 17 files changed, 1155 insertions(+), 5 deletions(-) create mode 100644 tensilelite/Tensile/Source/lib/include/Tensile/RegressionTree.hpp create mode 100644 tensilelite/Tensile/Source/lib/include/Tensile/RegressionTreeLibrary.hpp create mode 100644 tensilelite/Tensile/Source/lib/include/Tensile/Serialization/RegressionTreeLibrary.hpp diff --git a/tensilelite/Tensile/Contractions.py b/tensilelite/Tensile/Contractions.py index 99362aa3bf..bc284ca6ec 100644 --- a/tensilelite/Tensile/Contractions.py +++ b/tensilelite/Tensile/Contractions.py @@ -543,7 +543,12 @@ class SizeMapping: 'workGroupMappingXCC', 'workGroupMappingXCCGroup', 'globalSplitUCoalesced', - 'globalSplitUWorkGroupMappingRoundRobin' + 'globalSplitUWorkGroupMappingRoundRobin', + 'ldsBuffer', + 'localReadVectorWidth', + 'numLoadsCoalescedA', + 'storeVectorWidth', + 'vectorWidthA' ] @classmethod @@ -586,7 +591,12 @@ def FromOriginalState(cls, d): workGroupMappingXCC = d['WorkGroupMappingXCC'], workGroupMappingXCCGroup = d['WorkGroupMappingXCCGroup'], globalSplitUCoalesced = d['GlobalSplitUCoalesced'], - globalSplitUWorkGroupMappingRoundRobin = d['GlobalSplitUWorkGroupMappingRoundRobin'] + globalSplitUWorkGroupMappingRoundRobin = d['GlobalSplitUWorkGroupMappingRoundRobin'], + ldsBuffer = d['1LDSBuffer'], + localReadVectorWidth = d['LocalReadVectorWidth'], + numLoadsCoalescedA = d['NumLoadsCoalescedA'], + storeVectorWidth = d['StoreVectorWidth'], + vectorWidthA = d['VectorWidthA'] ) @classmethod diff --git a/tensilelite/Tensile/SolutionLibrary.py b/tensilelite/Tensile/SolutionLibrary.py index 18f500c136..1bd8424a08 100644 --- a/tensilelite/Tensile/SolutionLibrary.py +++ b/tensilelite/Tensile/SolutionLibrary.py @@ -218,6 +218,46 @@ def __init__(self, features, trees, nullValue): self.trees = trees self.nullValue = nullValue +class RegressionTreeLibrary: + Tag = "RegressionTree" + StateKeys = [("type", "tag"), "table", "trees", "solutionFeatures", "problemFeatures"] + + @classmethod + def FromOriginalState(cls, d, solutions): + origTable = d["table"] + table = [] + + try: + indexStart = origTable[0] + indexOffset = origTable[1] + for index in range(indexStart, indexStart + indexOffset): + value = IndexSolutionLibrary(solutions[index]) + table.append(value) + except KeyError: + pass + + trees = d["trees"] + solution_features = d["solutionFeatures"] + problem_features = d["problemFeatures"] + return cls(table, trees, solution_features, problem_features) + + @property + def tag(self): + return self.__class__.Tag + + def merge(self, other): + raise RuntimeError( + "RegressionTreeLibrary does not support merging." + ) + + def remapSolutionIndices(self, indexMap): + pass + + def __init__(self, table, trees, solution_features, problem_features): + self.table = table + self.trees = trees + self.solutionFeatures = solution_features + self.problemFeatures = problem_features class ProblemMapLibrary: Tag = "ProblemMap" @@ -393,6 +433,12 @@ def selection(d, problemType, solutions, library, placeholderName): treeLib = DecisionTreeLibrary.FromOriginalState(lib, solutions) library.rows.append({"predicate": predicate, "library": treeLib}) + elif d["LibraryType"] == "RegressionTree": + predicate = Properties.Predicate(tag="TruePred") + + regressionLib = RegressionTreeLibrary.FromOriginalState(d["Library"], solutions) + library = PredicateLibrary(tag="Problem") + library.rows.append({"predicate": predicate, "library": regressionLib}) else: assert 0 and "Unrecognized LibraryType." diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/ContractionProblem.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/ContractionProblem.hpp index 883c04715d..b5fade2448 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/ContractionProblem.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/ContractionProblem.hpp @@ -908,6 +908,9 @@ namespace TensileLite case ExperimentalOption::StreamK: return PerformanceMetric::ExperimentalStreamK; + case ExperimentalOption::RTree: + return PerformanceMetric::ExperimentalRTree; + default: // warning? return m_performanceMetric; @@ -1118,6 +1121,8 @@ namespace TensileLite return m_arithmeticIntensity; } + float getLog10Flops() const; + virtual std::vector const constants() const { std::vector c = {{"alpha", m_alphaType}, {"beta", m_betaType}}; diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/ContractionProblemPredicates.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/ContractionProblemPredicates.hpp index b18cdd2811..a139b219a3 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/ContractionProblemPredicates.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/ContractionProblemPredicates.hpp @@ -1900,6 +1900,40 @@ namespace TensileLite } }; + struct ExperimentalRTree + : public Predicate_CRTP + { + enum + { + HasIndex = false, + HasValue = false + }; + + ExperimentalRTree() = default; + + static std::string Type() + { + return "ExperimentalRTree"; + } + + virtual bool operator()(ContractionProblemGemm const& problem) const override + { + return (problem.performanceMetric() == PerformanceMetric::ExperimentalRTree); + } + + virtual bool debugEval(ContractionProblemGemm const& problem, + std::ostream& stream) const override + { + return debugEvalCmp(problem, + stream, + "prob", + problem.performanceMetric(), + "==", + "sol: PerformanceMetric::ExperimentalRTree", + PerformanceMetric::ExperimentalRTree); + } + }; + struct EqualityMatching : public Predicate_CRTP { diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp index a97e97f100..3fb45860af 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/ContractionSolution.hpp @@ -149,6 +149,12 @@ namespace TensileLite int workGroupMappingXCCGroup = 0; bool globalSplitUCoalesced = false; bool globalSplitUWorkGroupMappingRoundRobin = false; + + int ldsBuffer = 0; + size_t localReadVectorWidth = 1; + size_t numLoadsCoalescedA = 1; + size_t storeVectorWidth = 1; + size_t vectorWidthA = 1; }; /** diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/MLFeatures.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/MLFeatures.hpp index 7b9ee66fcc..8596a98770 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/MLFeatures.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/MLFeatures.hpp @@ -146,6 +146,26 @@ namespace TensileLite } }; + struct BatchSize : public MLFeature_CRTP + { + enum + { + HasIndex = true, + HasValue = false + }; + size_t index; + + static std::string Type() + { + return "BatchSize"; + } + + virtual float operator()(ContractionProblemGemm const& problem) const + { + return (float)problem.batchSize(index); + } + }; + struct Tile0Granularity : public MLFeature_CRTP { enum @@ -230,6 +250,234 @@ namespace TensileLite } }; + struct Log10Flops : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + + }; + + static std::string Type() + { + return "Log10Flops"; + } + + virtual float operator()(ContractionProblemGemm const& problem) const + { + return problem.getLog10Flops(); + } + }; + + struct MacroTile0 : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "MacroTile0"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.macroTile.x; + } + }; + + struct MacroTile1 : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "MacroTile1"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.macroTile.y; + } + }; + + struct DepthU : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "DepthU"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.depthU; + } + }; + + struct LdsBuffer : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "1LDSBuffer"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.ldsBuffer; + } + }; + + struct GlobalReadVWA : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "GlobalReadVectorWidthA"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.grvwA; + } + }; + + struct GlobalReadVWB : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "GlobalReadVectorWidthB"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.grvwB; + } + }; + + struct LocalReadVW : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "LocalReadVectorWidth"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.localReadVectorWidth; + } + }; + + struct NumLoadsCoalescedA : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "NumLoadsCoalescedA"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.numLoadsCoalescedA; + } + }; + + struct StoreVW : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "StoreVectorWidth"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.storeVectorWidth; + } + }; + + struct VectorWidthA : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "VectorWidthA"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.vectorWidthA; + } + }; + + struct WorkGroupMapping : public MLFeature_CRTP + { + enum + { + HasIndex = false, + HasValue = false, + }; + + static std::string Type() + { + return "WorkGroupMapping"; + } + + virtual float operator()(ContractionSolution const& solution) const + { + return (float)solution.sizeMapping.workGroupMapping; + } + }; /** * @} */ diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/PerformanceMetricTypes.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/PerformanceMetricTypes.hpp index 6ea10a570e..0f01ec2681 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/PerformanceMetricTypes.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/PerformanceMetricTypes.hpp @@ -54,6 +54,7 @@ namespace TensileLite None = 0, DTree = 1, StreamK = 2, + RTree = 3, Count }; @@ -67,6 +68,7 @@ namespace TensileLite DeviceEfficiency, ExperimentalDTree, ExperimentalStreamK, + ExperimentalRTree, Count }; @@ -152,7 +154,11 @@ namespace TensileLite : public BasePerformanceMetricInfo { }; - + template <> + struct PerformanceMetricInfo + : public BasePerformanceMetricInfo + { + }; /** * @} */ diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/RegressionTree.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/RegressionTree.hpp new file mode 100644 index 0000000000..85951dbb35 --- /dev/null +++ b/tensilelite/Tensile/Source/lib/include/Tensile/RegressionTree.hpp @@ -0,0 +1,214 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include +#include + +namespace TensileLite +{ + /** + * \ingroup Tensile + * \defgroup RegressionTree Regression Tree + * + * @brief Regression model to estimate efficiencies + * + * Group of trees used to estimate efficiency values for solutions in the + * library. Used for RegressionTreeLibrary. + */ + + /** + * \ingroup RegressionTree + */ + namespace RegressionTree + { + struct Node + { + int featureIdx; // Index into feature array + float threshold; // Decision threshold value + float nextIdxLTE; // Next node index if val <= threshold, may be leaf value + float nextIdxGT; // Next node index if val > threshold, may be leaf value + }; + + /** + * @brief Generic regression tree + * + * @tparam Key type used for deciding + * @tparam ReturnValue type returned by tree + */ + template + struct Tree + { + + Tree() = default; + Tree(std::vector tree) + : tree(std::move(tree)) + { + } + + float predict(Key const& key) const + { + int nodeIdx = 0; + int treeSize = tree.size(); + Node currentNode; + float treeValue; + + while(nodeIdx < treeSize) + { + currentNode = tree[nodeIdx]; + bool branchLTE = key[currentNode.featureIdx] <= currentNode.threshold; + treeValue = branchLTE ? currentNode.nextIdxLTE : currentNode.nextIdxGT; + + if(treeValue < 1.0f) + return treeValue; + else + nodeIdx = static_cast(treeValue); + } + + throw std::runtime_error("Regression Tree out of bounds error."); + return false; + } + + bool valid(bool verbose = false) const + { + size_t treeSize = tree.size(); + Node currentNode; + bool valid = true; + + if(treeSize == 0) + { + if(verbose) + { + std::cout << "Tree invalid: no nodes." << std::endl; + } + return false; + } + + if(treeSize > ((size_t)std::numeric_limits::max() + 1)) + { + /* Restrict size to +ve int range, -ve idxs for reserved values */ + if(verbose) + { + std::cout << "Tree invalid: too many nodes." << std::endl; + } + return false; + } + + // Check for any invalid nodes + for(int nodeIdx = 0; nodeIdx < treeSize; nodeIdx++) + { + currentNode = tree[nodeIdx]; + + // Avoid OOB on feature array + if((currentNode.featureIdx < 0) + || (currentNode.featureIdx >= std::tuple_size::value)) + { + if(verbose) + { + std::cout << "Node " << std::to_string(nodeIdx) + << " invalid: Unrecognised type '" + << std::to_string(currentNode.featureIdx) << "'" << std::endl; + } + valid = false; + } + } + return valid; + } + + std::vector tree; + }; + + /** + * @brief Abstract base class for a group of regression trees + * + * @tparam Key used to query trees + * @tparam ReturnValue type returned by trees + */ + template + struct Forest + { + Forest() = default; + + virtual ~Forest() = default; + + virtual ReturnValue computeEfficiency(Key const& inputParameters) const = 0; + + virtual std::string description() const = 0; + }; + + /** + * @brief Forest that returns value from trees + * @tparam Key used to query trees + * @tparam ReturnValue type returned by trees + */ + template + struct BasicForest : public Forest + { + using Base = Forest; + using Tree = Tree; + + BasicForest() {} + + virtual ReturnValue computeEfficiency(Key const& inputParameters) const override + { + bool debug = Debug::Instance().getSolutionSelectionTrace(); + if(debug) + { + std::cout << "Forest " << this->description() << std::endl; + std::cout << "Entering solution selection evaluation loop. Searching forest." + << std::endl; + } + + ReturnValue rv = 0; + size_t treenum = 0; + for(Tree const& tree : trees) + { + + if(debug) + std::cout << "Running predict tree: " << treenum << std::endl; + + ReturnValue result = tree.predict(inputParameters); + rv += result; + if(debug) + { + std::cout << "Tree " << treenum << " predicts: " << result << std::endl; + std::cout << "Accummulated value: " << rv << std::endl; + } + } + return rv; + } + + virtual std::string description() const override + { + return concatenate("RegressionTree Forest: Number of trees ", trees.size()); + } + + std::vector trees; + }; + } // namespace RegressionTree +} // namespace TensileLite diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/RegressionTreeLibrary.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/RegressionTreeLibrary.hpp new file mode 100644 index 0000000000..a8bea1fe30 --- /dev/null +++ b/tensilelite/Tensile/Source/lib/include/Tensile/RegressionTreeLibrary.hpp @@ -0,0 +1,226 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace TensileLite +{ + /** + * \ingroup SolutionLibrary + * + * Uses a set of regression trees to rank solutions for a given size. + */ + + template + struct RegressionTreeLibrary : public SolutionLibrary + { + using Forest = RegressionTree::Forest, float>; + using SolutionFeatures = std::vector>>; + using ProblemFeatures = std::vector>>; + + std::map> solutionmap; + std::shared_ptr forest; + SolutionFeatures solFeatures; + ProblemFeatures probFeatures; + + static std::string Type() + { + return "RegressionTree"; + } + virtual std::string type() const override + { + return Type(); + } + virtual std::string description() const override + { + if(forest == nullptr) + return concatenate(type(), ", forest: nullptr"); + else + return concatenate(type(), ": ", forest->description()); + } + + virtual std::shared_ptr getSolutionByIndex(MyProblem const& problem, + Hardware const& hardware, + const int index) const override + { + const bool experimental = Debug::Instance().useExperimentalSelection(); + if(!experimental) + { + // If the experimental library mode is not on treat it like it asserted out + return nullptr; + } + // ; + auto indexMatch = solutionmap.find(index); + if(indexMatch != solutionmap.end()) + return indexMatch->second; + return nullptr; + } + + virtual std::shared_ptr findBestSolution(MyProblem const& problem, + Hardware const& hardware, + double* fitness + = nullptr) const override + { + std::vector problemkey + = ProblemKey::keyForProblem, MyProblem, float>( + problem, this->probFeatures); + + float result = 0.0; + float bestEfficiency = 0.0; + std::shared_ptr bestMatch = nullptr; + + for(auto const& row : solutionmap) + { + std::vector solutionkey + = ProblemKey::keyForProblem, MySolution, float>( + *row.second, this->solFeatures); + std::vector key; + key.reserve(solutionkey.size() + problemkey.size()); + key.insert(key.end(), solutionkey.begin(), solutionkey.end()); + key.insert(key.end(), problemkey.begin(), problemkey.end()); + result = forest->computeEfficiency(key); + + if(result > bestEfficiency) + { + bestEfficiency = result; + bestMatch = row.second; + } + } + return bestMatch; + } + + virtual SolutionSet + findAllSolutions(MyProblem const& problem, + Hardware const& hardware, + SolutionLibrarySearchType searchType + = SolutionLibrarySearchType::DEFAULT) const override + { + if(searchType != SolutionLibrarySearchType::DEFAULT) + { + // if the solution library search is not default then return an empty + // set of solutions. + SolutionSet rv; + return rv; + } + + const bool experimental = Debug::Instance().useExperimentalSelection(); + if(!experimental) + { + // Skip the search for solutions if the environment variable + // that enables the experimental method is not set + SolutionSet rv; + return rv; + } + SolutionSet rv; + for(auto const& row : solutionmap) + rv.insert(row.second); + + return rv; + } + + virtual SolutionVector findTopSolutions(MyProblem const& problem, + Hardware const& hardware, + int numSolutions) const override + { + // TODO gather topN in sorted map + std::vector problemkey + = ProblemKey::keyForProblem, MyProblem, float>( + problem, this->probFeatures); + + std::map> solutionRank; + for(auto const& row : solutionmap) + { + std::vector solutionkey + = ProblemKey::keyForProblem, MySolution, float>( + *row.second, this->solFeatures); + + std::vector key; + key.reserve(solutionkey.size() + problemkey.size()); + key.insert(key.end(), solutionkey.begin(), solutionkey.end()); + key.insert(key.end(), problemkey.begin(), problemkey.end()); + float result = forest->computeEfficiency(key); + + if(solutionRank.size() < numSolutions) + solutionRank.insert(std::make_pair(result, row.first)); + else if(solutionRank.rbegin()->first < result) + { + auto minIter = solutionRank.rbegin(); + solutionRank.erase(minIter->first); + solutionRank.insert(std::make_pair(result, row.first)); + } + } + SolutionVector rv; + for(auto const& row : solutionRank) + { + auto indexMatch = solutionmap.find(row.second); + rv.push_back(indexMatch->second); + } + return rv; + } + + virtual SolutionSet + findAllSolutionsGroupedGemm(std::vector const& problems, + Hardware const& hardware, + SolutionLibrarySearchType searchType + = SolutionLibrarySearchType::DEFAULT) const override + { + if(searchType != SolutionLibrarySearchType::DEFAULT) + { + // if the solution library search is notSolutionSet default then return an empty + // set of solutions + SolutionSet rv; + return rv; + } + + const bool experimental = Debug::Instance().useExperimentalSelection(); + if(!experimental) + { + // Skip the search for solutions if the environment variable + // that enables the experimental method is not set + SolutionSet rv; + return rv; + } + + SolutionSet rv; + for(auto const& row : solutionmap) + rv.insert(row.second); + + return rv; + } + }; + +} // namespace TensileLite diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/Containers.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/Containers.hpp index 90b712060d..95d1e95e0c 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/Containers.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/Containers.hpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -191,6 +192,9 @@ namespace TensileLite TENSILE_SERIALIZE_VECTOR(false, std::shared_ptr); + TENSILE_SERIALIZE_VECTOR( + false, std::shared_ptr>); + template struct SequenceTraits>, IO> : public DefaultSequenceTraits>, IO, false> @@ -220,6 +224,17 @@ namespace TensileLite { }; + TENSILE_SERIALIZE_VECTOR(true, TensileLite::RegressionTree::Node); + template + struct SequenceTraits>, IO> + : public DefaultSequenceTraits< + std::vector>, + IO, + false> + { + }; + + template struct SequenceTraits, IO> { diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/ContractionPredicates.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/ContractionPredicates.hpp index b8675c354d..2b008b45fd 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/ContractionPredicates.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/ContractionPredicates.hpp @@ -99,6 +99,7 @@ namespace TensileLite Base::template Pair(), Base::template Pair(), Base::template Pair(), + Base::template Pair(), Base::template Pair(), Base::template Pair(), Base::template Pair(), @@ -394,6 +395,12 @@ namespace TensileLite { }; + template + struct MappingTraits + : public AutoMappingTraits + { + }; + template struct MappingTraits : public AutoMappingTraits diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/ContractionSolution.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/ContractionSolution.hpp index fdd7ab5d45..1426bb5e6e 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/ContractionSolution.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/ContractionSolution.hpp @@ -121,6 +121,12 @@ namespace TensileLite iot::mapRequired(io, "globalSplitUWorkGroupMappingRoundRobin", s.globalSplitUWorkGroupMappingRoundRobin); + + iot::mapRequired(io, "ldsBuffer", s.ldsBuffer); + iot::mapRequired(io, "localReadVectorWidth", s.localReadVectorWidth); + iot::mapRequired(io, "numLoadsCoalescedA", s.numLoadsCoalescedA); + iot::mapRequired(io, "storeVectorWidth", s.storeVectorWidth); + iot::mapRequired(io, "vectorWidthA", s.vectorWidthA); } const static bool flow = false; diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/MLFeatures.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/MLFeatures.hpp index d622e50dae..e612e4c342 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/MLFeatures.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/MLFeatures.hpp @@ -93,10 +93,50 @@ namespace TensileLite return SubclassMap({Base::template Pair(), Base::template Pair(), Base::template Pair(), + Base::template Pair(), Base::template Pair(), Base::template Pair(), Base::template Pair(), - Base::template Pair()}); + Base::template Pair(), + Base::template Pair()}); + } + }; + + // Set Flow + template + struct MappingTraits>, IO> + : public BaseClassMappingTraits, IO, true> + { + }; + + template + struct SubclassMappingTraits, IO> + : public DefaultSubclassMappingTraits< + SubclassMappingTraits, IO>, + MLFeatures::MLFeature, + IO> + { + using Self = SubclassMappingTraits, IO>; + using Base = DefaultSubclassMappingTraits< + SubclassMappingTraits, IO>, + MLFeatures::MLFeature, + IO>; + using SubclassMap = typename Base::SubclassMap; + const static SubclassMap subclasses; + + static typename Base::SubclassMap GetSubclasses() + { + return SubclassMap({Base::template Pair(), + Base::template Pair(), + Base::template Pair(), + Base::template Pair(), + Base::template Pair(), + Base::template Pair(), + Base::template Pair(), + Base::template Pair(), + Base::template Pair(), + Base::template Pair(), + Base::template Pair()}); } }; @@ -109,6 +149,15 @@ namespace TensileLite ContractionProblemFeatureSMT::subclasses = ContractionProblemFeatureSMT::GetSubclasses(); + template + using ContractionSolutionFeatureSMT + = SubclassMappingTraits, IO>; + + template + const typename ContractionSolutionFeatureSMT::SubclassMap + ContractionSolutionFeatureSMT::subclasses + = ContractionSolutionFeatureSMT::GetSubclasses(); + template struct MappingTraits : public AutoMappingTraits @@ -127,6 +176,12 @@ namespace TensileLite { }; + template + struct MappingTraits + : public AutoMappingTraits + { + }; + template struct MappingTraits : public AutoMappingTraits @@ -150,5 +205,77 @@ namespace TensileLite : public AutoMappingTraits { }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; + + template + struct MappingTraits + : public AutoMappingTraits + { + }; } // namespace Serialization } // namespace TensileLite diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/RegressionTreeLibrary.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/RegressionTreeLibrary.hpp new file mode 100644 index 0000000000..f9973a3a93 --- /dev/null +++ b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/RegressionTreeLibrary.hpp @@ -0,0 +1,172 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include + +#include +#include + +namespace TensileLite +{ + namespace Serialization + { + template + struct MappingTraits, IO> + { + using Tree = RegressionTree::Tree; + using iot = IOTraits; + + static void mapping(IO& io, Tree& tree) + { + iot::mapRequired(io, "tree", tree.tree); + } + const static bool flow = false; + }; + + template + struct MappingTraits, IO> + { + using Forest = RegressionTree::BasicForest; + using iot = IOTraits; + + static void mapping(IO& io, Forest& lib) + { + int32_t index = -1; + iot::mapRequired(io, "trees", lib.trees); + } + + const static bool flow = false; + }; + + template + struct MappingTraits, IO> + { + using Library = RegressionTreeLibrary; + using iot = IOTraits; + + static void mapping(IO& io, Library& lib) + { + auto ctx = static_cast*>(iot::getContext(io)); + if(ctx == nullptr) + { + iot::setError(io, + "RegressionTreeLibrary requires that context be " + "set to a SolutionMap."); + } + std::vector mappingIndices; + if(iot::outputting(io)) + { + mappingIndices.reserve(lib.solutionmap.size()); + + for(auto const& pair : lib.solutionmap) + mappingIndices.push_back(pair.first); + + iot::mapRequired(io, "table", mappingIndices); + } + else + { + iot::mapRequired(io, "table", mappingIndices); + if(mappingIndices.empty()) + iot::setError(io, + "RegressionTreeLibrary requires non empty " + "mapping index set."); + + for(int index : mappingIndices) + { + auto slnIter = ctx->solutions->find(index); + if(slnIter == ctx->solutions->end()) + { + iot::setError( + io, + concatenate("[RegressionTreeLibrary] Invalid solution index: ", + index)); + } + else + { + auto solution = slnIter->second; + lib.solutionmap.insert(std::make_pair(index, solution)); + } + } + } + + using Forest = RegressionTree::BasicForest, float>; + std::shared_ptr forest; + if(iot::outputting(io)) + { + forest = std::dynamic_pointer_cast(lib.forest); + } + else + { + forest = std::make_shared(); + lib.forest = forest; + } + MappingTraits::mapping(io, *forest); + + using SolutionFeatures + = std::vector>>; + SolutionFeatures solFeatures; + if(iot::outputting(io)) + { + solFeatures = lib.solFeatures; + } + iot::mapOptional(io, "solutionFeatures", solFeatures); + lib.solFeatures = solFeatures; + + using ProblemFeatures + = std::vector>>; + ProblemFeatures probFeatures; + if(iot::outputting(io)) + { + probFeatures = lib.probFeatures; + } + iot::mapOptional(io, "problemFeatures", probFeatures); + lib.probFeatures = probFeatures; + } + const static bool flow = false; + }; + + template + struct MappingTraits + { + using Node = typename RegressionTree::Node; + using iot = IOTraits; + + static void mapping(IO& io, Node& node) + { + iot::mapRequired(io, "featureIdx", node.featureIdx); + iot::mapRequired(io, "threshold", node.threshold); + iot::mapRequired(io, "nextIdxLTE", node.nextIdxLTE); + iot::mapRequired(io, "nextIdxGT", node.nextIdxGT); + } + + const static bool flow = true; + }; + + } // namespace Serialization +} // namespace TensileLite diff --git a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/SolutionLibrary.hpp b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/SolutionLibrary.hpp index f46b29a1f0..c65d250bcc 100644 --- a/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/SolutionLibrary.hpp +++ b/tensilelite/Tensile/Source/lib/include/Tensile/Serialization/SolutionLibrary.hpp @@ -45,6 +45,7 @@ #include #include #include +#include namespace TensileLite { @@ -80,7 +81,8 @@ namespace TensileLite Base::template Pair>(), Base::template Pair>(), Base::template Pair>(), - Base::template Pair>()}); + Base::template Pair>(), + Base::template Pair>()}); } }; diff --git a/tensilelite/Tensile/Source/lib/source/ContractionProblem.cpp b/tensilelite/Tensile/Source/lib/source/ContractionProblem.cpp index 03fecc4c1c..b83ae799ba 100644 --- a/tensilelite/Tensile/Source/lib/source/ContractionProblem.cpp +++ b/tensilelite/Tensile/Source/lib/source/ContractionProblem.cpp @@ -1104,6 +1104,27 @@ namespace TensileLite m_arithmeticIntensity = gflop / gbyte; } + float ContractionProblemGemm::getLog10Flops() const + { + size_t problemSize = 1; + for(size_t i = 0; i < m_problemSizes.size(); ++i) + { + problemSize *= m_problemSizes[i]; + } + double flops = 2 * problemSize; + + if(m_beta != 0) // If problem includes beta, update flops + { + size_t cSize = 1; + for(size_t i = 0; i < c().dimensions(); ++i) + { + cSize *= c().sizes()[i]; + } + flops += 2 * cSize; + } + return (float)std::log10(flops); + } + size_t ContractionProblemGemm::freeSizeA(size_t idx) const { return m_freeSizesA.at(idx); diff --git a/tensilelite/Tensile/Source/lib/source/PerformanceMetricTypes.cpp b/tensilelite/Tensile/Source/lib/source/PerformanceMetricTypes.cpp index c80817d4f8..fbd4b53fb6 100644 --- a/tensilelite/Tensile/Source/lib/source/PerformanceMetricTypes.cpp +++ b/tensilelite/Tensile/Source/lib/source/PerformanceMetricTypes.cpp @@ -45,6 +45,8 @@ namespace TensileLite return "ExperimentalDTree"; case PerformanceMetric::ExperimentalStreamK: return "ExperimentalStreamK"; + case PerformanceMetric::ExperimentalRTree: + return "ExperimentalRTree"; case PerformanceMetric::Count: default:; @@ -66,6 +68,8 @@ namespace TensileLite return "DTree"; case PerformanceMetric::ExperimentalStreamK: return "StreamK"; + case PerformanceMetric::ExperimentalRTree: + return "RTree"; case PerformanceMetric::Count: default:; @@ -106,6 +110,7 @@ namespace TensileLite registerTypeInfo(); registerTypeInfo(); registerTypeInfo(); + registerTypeInfo(); } void PerformanceMetricTypeInfo::registerAllTypeInfoOnce()