Skip to content

Commit

Permalink
KernelTuningNet for CK [Update CK Fwd, CK Wrw + New integration CK Bwd] (#3464)
Browse files Browse the repository at this point in the history

* ck fwd, wrw, bwd KTN updated

* solvers merge

* format, convert cout to logs

* hip tidy

---------

Co-authored-by: Christopher Erb <[email protected]>
Co-authored-by: BrianHarrisonAMD <[email protected]>
  • Loading branch information
3 people authored Jan 13, 2025
1 parent d94ef1f commit 200bc5d
Show file tree
Hide file tree
Showing 15 changed files with 323 additions and 158 deletions.
15 changes: 12 additions & 3 deletions src/conv/heuristics/ai_heuristics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,11 @@ bool ModelSetParams(const std::string& arch,
{
auto model = GetModel(arch, solver);

std::stringstream ss;
for(int i = 0; i < 17; ++i)
ss << features[i * 17 + i] << ", ";
MIOPEN_LOG_I2("Features: " << ss.str());

// get context
int dim = 0;
if(transform_features)
Expand All @@ -727,10 +732,11 @@ bool ModelSetParams(const std::string& arch,
default: return false;
}

MIOPEN_LOG_I2("PREDICT TYPE: " << model->metadata.predict_type);

// run decoder to set kernel parameters
for(size_t i = 0, num_tuning_params = 1; i < num_tuning_params; ++i)
{

if(i == 0 && (model->metadata.predict_type == 0u))
num_tuning_params = model->metadata.num_tuning_params[dir];

Expand All @@ -740,7 +746,9 @@ bool ModelSetParams(const std::string& arch,
// order tokens according to their scores
std::priority_queue<std::pair<float, int>> pq;
for(int j = 0; j < token_scores.size(); j++)
{
pq.push(std::make_pair(token_scores[j], j)); // sort by value at index
}

// find a token whose value is a valid kernel parameter for the i-th position
int output_token_index = -1;
Expand All @@ -751,11 +759,12 @@ bool ModelSetParams(const std::string& arch,
std::string value = model->metadata.tuning_decodings[std::to_string(token)];
pq.pop();

MIOPEN_LOG_I2(std::to_string((int)i) + ": " + std::to_string(token) + " " + value);
if(value == "-1") // if token-value is "-1", then decoding has finished
{
auto stop = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
MIOPEN_LOG_I2("Model ran for " << duration.count() << " micro-seconds");
MIOPEN_LOG_I2("KTN ran for " << duration.count() << " micro-seconds. Ended at -1.");
return false;
}
if(validator(i, value)) // if token-value is a valid kernel parameter, it's set
Expand All @@ -773,7 +782,7 @@ bool ModelSetParams(const std::string& arch,

auto stop = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
MIOPEN_LOG_I2("Model ran for " << duration.count() << " micro-seconds");
MIOPEN_LOG_I2("KTN ran for " << duration.count() << " micro-seconds");
return true;
}

Expand Down
5 changes: 4 additions & 1 deletion src/include/miopen/conv/solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4768,7 +4768,10 @@ struct PerformanceConfigHipImplicitGemmGroupBwdXdlops
bool RunParameterPredictionModel(const ExecutionContext& ctx,
const miopen::conv::ProblemDescription& problem);
void InitHeuristicKernelIDs();
bool ModelApplyToken(int idx, std::string value);
bool ModelApplyToken(int idx,
std::string value,
const std::string& arch,
const miopen::conv::ProblemDescription& problem);
#endif
template <typename DataType>
void Init(const miopen::conv::ProblemDescription&);
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions src/kernels/gfx942_ConvHipIgemmGroupBwdXdlops_metadata.ktn.model
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"decodings": {
"tunings": {
"0": "0",
"1": "256",
"2": "128",
"3": "64",
"4": "64",
"5": "32",
"6": "128",
"7": "256",
"8": "128",
"9": "64",
"10": "32",
"11": "256",
"12": "8",
"13": "2",
"14": "Default",
"15": "Filter1x1Stride1Pad0",
"16": "1",
"17": "2",
"18": "4",
"19": "2",
"20": "1",
"21": "4",
"22": "4",
"23": "8",
"24": "1",
"25": "4",
"26": "1",
"27": "8",
"28": "-1"
}
},
"predict_type": 0,
"num_tuning_params": {
"fwd": 9,
"bwd": 9,
"wrw": 9
}
}

Large diffs are not rendered by default.

Large diffs are not rendered by default.

117 changes: 62 additions & 55 deletions src/kernels/gfx942_ConvHipIgemmGroupFwdXdlops_metadata.ktn.model
Original file line number Diff line number Diff line change
@@ -1,75 +1,82 @@
{
"predict_type": 1,
"num_tuning_params": {
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3": 17,
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle": 15
},
"decodings": {
"tunings": {
"0": "0",
"1": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"2": "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"3": "256",
"4": "128",
"5": "64",
"3": "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
"4": "256",
"5": "128",
"6": "64",
"7": "16",
"7": "64",
"8": "128",
"9": "32",
"10": "224",
"9": "16",
"10": "32",
"11": "256",
"12": "128",
"13": "32",
"14": "64",
"15": "16",
"16": "256",
"17": "224",
"18": "32",
"19": "64",
"20": "16",
"21": "128",
"22": "Default",
"23": "Filter1x1Pad0",
"12": "224",
"13": "128",
"14": "32",
"15": "64",
"16": "16",
"17": "256",
"18": "224",
"19": "16",
"20": "64",
"21": "32",
"22": "128",
"23": "Default",
"24": "OddC",
"25": "Filter1x1Stride1Pad0",
"26": "32",
"27": "16",
"25": "Filter1x1Pad0",
"26": "Filter1x1Stride1Pad0",
"27": "Filter3x3",
"28": "32",
"29": "16",
"30": "1",
"31": "2",
"32": "7",
"33": "4",
"34": "8",
"35": "2",
"36": "1",
"37": "4",
"38": "8",
"39": "7",
"40": "8",
"41": "4",
"42": "1",
"30": "32",
"31": "16",
"32": "1",
"33": "2",
"34": "4",
"35": "8",
"36": "7",
"37": "2",
"38": "1",
"39": "4",
"40": "7",
"41": "8",
"42": "4",
"43": "8",
"44": "4",
"45": "1",
"44": "1",
"45": "4",
"46": "8",
"47": "4",
"48": "1",
"49": "2",
"47": "1",
"48": "4",
"49": "8",
"50": "1",
"51": "2",
"52": "1",
"53": "2",
"54": "-1",
"55": "BlkGemmPipelineScheduler:Intrawave",
"56": "BlkGemmPipelineScheduler:Interwave",
"57": "-1",
"58": "BlkGemmPipelineVersion:v1",
"59": "BlkGemmPipelineVersion:v2",
"60": "BlkGemmPipelineVersion:v4",
"61": "BlkGemmPipelineVersion:v3",
"62": "BlkGemmPipelineVersion:v5",
"63": "-1"
"54": "1",
"55": "2",
"56": "1",
"57": "BlkGemmPipelineScheduler:Intrawave",
"58": "BlkGemmPipelineScheduler:Interwave",
"59": "8",
"60": "16",
"61": "-1",
"62": "32",
"63": "-1",
"64": "BlkGemmPipelineVersion:v3",
"65": "BlkGemmPipelineVersion:v1",
"66": "BlkGemmPipelineVersion:v2",
"67": "BlkGemmPipelineVersion:v4",
"68": "BlkGemmPipelineVersion:v5",
"69": "-1"
}
},
"predict_type": 1,
"num_tuning_params": {
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle": 15,
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3": 16,
"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor": 14
}
}
}

Large diffs are not rendered by default.

Large diffs are not rendered by default.

99 changes: 50 additions & 49 deletions src/kernels/gfx942_ConvHipIgemmGroupWrwXdlops_metadata.ktn.model
Original file line number Diff line number Diff line change
@@ -1,86 +1,87 @@
{
"predict_type": 1,
"num_tuning_params": {
"DeviceGroupedConvBwdWeight_Xdl_CShuffle": 15,
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle": 18
},
"decodings": {
"tunings": {
"0": "0",
"1": "DeviceGroupedConvBwdWeight_Xdl_CShuffle",
"2": "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle",
"3": "64",
"4": "128",
"3": "128",
"4": "64",
"5": "256",
"6": "64",
"7": "128",
"8": "32",
"6": "128",
"7": "32",
"8": "64",
"9": "256",
"10": "16",
"11": "32",
"12": "128",
"13": "64",
"12": "64",
"13": "128",
"14": "256",
"15": "16",
"16": "4",
"17": "32",
"18": "Default",
"19": "Filter1x1Stride1Pad0",
"20": "8",
"21": "4",
"18": "Filter1x1Stride1Pad0",
"19": "Default",
"20": "4",
"21": "8",
"22": "2",
"23": "1",
"24": "4",
"25": "1",
"26": "2",
"27": "4",
"28": "8",
"29": "4",
"30": "1",
"31": "2",
"28": "4",
"29": "2",
"30": "8",
"31": "1",
"32": "4",
"33": "1",
"34": "2",
"33": "2",
"34": "1",
"35": "8",
"36": "8",
"37": "4",
"36": "4",
"37": "8",
"38": "1",
"39": "2",
"40": "2",
"41": "1",
"42": "4",
"40": "1",
"41": "4",
"42": "2",
"43": "8",
"44": "8",
"45": "4",
"44": "4",
"45": "8",
"46": "1",
"47": "2",
"48": "4",
"49": "8",
"50": "16",
"51": "128",
"52": "32",
"48": "8",
"49": "16",
"50": "32",
"51": "4",
"52": "128",
"53": "2",
"54": "1",
"55": "64",
"56": "BlkGemmPipelineScheduler:Intrawave",
"57": "-1",
"58": "BlkGemmPipelineVersion:v5",
"58": "BlkGemmPipelineVersion:v1",
"59": "BlkGemmPipelineVersion:v2",
"60": "-1",
"61": "8",
"62": "4",
"63": "1",
"64": "2",
"65": "-1",
"66": "32",
"67": "64",
"60": "BlkGemmPipelineVersion:v5",
"61": "-1",
"62": "1",
"63": "4",
"64": "8",
"65": "2",
"66": "-1",
"67": "128",
"68": "16",
"69": "8",
"70": "128",
"71": "2",
"72": "4",
"73": "1",
"74": "-1"
"69": "64",
"70": "32",
"71": "8",
"72": "1",
"73": "2",
"74": "4",
"75": "-1"
}
},
"predict_type": 1,
"num_tuning_params": {
"DeviceGroupedConvBwdWeight_Xdl_CShuffle": 15,
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle": 18
}
}
1 change: 0 additions & 1 deletion src/ocl/gcn_asm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ bool ValidateGcnAssembler() { return true; }
#include <cstdlib>
#include <fstream>
#include <miopen/filesystem.hpp>
#include <miopen/env.hpp>
#include <miopen/errors.hpp>
#include <miopen/manage_ptr.hpp>
#include <miopen/write_file.hpp>
Expand Down
Loading

0 comments on commit 200bc5d

Please sign in to comment.