diff --git a/Tensile/SolutionStructs.py b/Tensile/SolutionStructs.py index 98db55d66..62b5fffbe 100644 --- a/Tensile/SolutionStructs.py +++ b/Tensile/SolutionStructs.py @@ -2435,9 +2435,16 @@ def isDirectToLdsDoable(state, tc): #TN # use for all precisions with TransposeLDS=1 - if state["ProblemType"]["DataType"].isHalf() and state["AssertSummationElementMultiple"] % (2 * state["GlobalLoadVectorWidth%c"%tc]) != 0: - reject(state, "can't use DirectToLds for FP16 with AssertSummationElementMultiple %u" % state["AssertSummationElementMultiple"]) - return False + if state["ProblemType"]["DataType"].isHalf(): + if state["AssertSummationElementMultiple"] % (2 * state["GlobalLoadVectorWidth%c"%tc]) != 0: + reject(state, "can't use DirectToLds for FP16 with AssertSummationElementMultiple %u" % state["AssertSummationElementMultiple"]) + return False + if state["ProblemType"]["TransposeA"] != True or state["ProblemType"]["TransposeB"] != False: + reject(state, "DirectToLds for FP16 currently only working for TN") + return False + if state["GlobalReadVectorWidth"] < 4: + reject(state, "GlobalReadVectorWidth must be 4 for DirectToLds HGEMM") + return False if state["ProblemType"]["DataType"].isBFloat16() and state["AssertSummationElementMultiple"] % (2 * state["GlobalLoadVectorWidth%c"%tc]) != 0: reject(state, "can't use DirectToLds for BF16 with AssertSummationElementMultiple %u" % state["AssertSummationElementMultiple"]) @@ -2447,13 +2454,6 @@ def isDirectToLdsDoable(state, tc): reject(state, "can't use DirectToLds for NumThreads % WavefrontSize != 0") return False - # GLVW*BPe only for precision(s) < 4 (bpe) - #if (state["ProblemType"]["TLU%c"%tc] == True and numBytes < 4): - if (numBytes < 4): - if state["GlobalLoadVectorWidth%c"%tc] * numBytes != 4: - reject(state, "can't use DirectToLds for bpe < 4 and GlobalLoadVectorWidth * numBytes != 4"%tc) - return False - if state["ProblemType"]["TLU%c"%tc] == state["UnrollMajorLDS%c" % tc]: reject(state, "can't use DirectToLds for TLU%c == UnrollMajorLDS%c"%(tc, tc)) return False diff --git a/Tensile/Tests/extended/direct_to_lds/dtl_hgemm.yaml b/Tensile/Tests/extended/direct_to_lds/dtl_hgemm.yaml new file mode 100644 index 000000000..cf2902b72 --- /dev/null +++ b/Tensile/Tests/extended/direct_to_lds/dtl_hgemm.yaml @@ -0,0 +1,129 @@ +TestParameters: + marks: [skip-gfx900, skip-gfx906, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030] # not supported by arch + +GlobalParameters: + NumElementsToValidate: -1 + BoundsCheck: True + KernelTime: True + # PrintSolutionRejectionReason: True + +BenchmarkProblems: + ######################################## + # TN + ######################################## + - # hgemm TN + - # ProblemType + OperationType: GEMM + DataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: True + TransposeB: False + UseBeta: True + Batched: True + + - # MFMA 16x16, VW = 4 + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + - EdgeType: ["ShiftPtr"] + - AtomicAddC: [False] + ForkParameters: + - MatrixInstruction: + - [16, 16, 16, 1, 1, 4,4, 2,2] + - [16, 16, 16, 1, 1, 4,4, 1,4] + - [16, 16, 16, 1, 1, 4,4, 4,1] + - [16, 16, 4, 4, 1, 4,4, 2,2] + - [16, 16, 4, 4, 2, 4,4, 2,2] + - [16, 16, 4, 4, 4, 4,4, 2,2] + # - [4, 4, 4, 16, 1, 4,4, 2,2] # Error + - [4, 4, 4, 16, 2, 4,4, 2,2] + - [4, 4, 4, 16, 4, 4,4, 2,2] + - [4, 4, 4, 16, 8, 4,4, 2,2] + # - [4, 4, 4, 16, 16, 4,4, 2,2] # Error + - ThreadTile: + - [ 8, 32 ] + - WorkGroup: + - [ 16, 16, 1 ] + - SourceSwap: [True] + - PrefetchGlobalRead: [1] + - AssertFree0ElementMultiple: [2] + # - AssertFree1ElementMultiple: [2] + - WorkGroupMapping: [8] + - PrefetchLocalRead: [1, 2, 3] + - GlobalSplitU: [1] + - DepthU: [16, 32, 64] + - StoreVectorWidth: [4] + - VectorWidth: [4] + - GlobalReadVectorWidth: [4] # Error with GRVW < 4 + - LocalReadVectorWidth: [4] + - DirectToLds: [True] + - DirectToVgprA: [False] # DirectToVgpr only for dgemm right now + - DirectToVgprB: [False] # DirectToVgpr only for dgemm right now + - WaveSeparateGlobalReadA: [0, 1] + - WaveSeparateGlobalReadB: [0, 1] + - NumLoadsCoalescedA: [1] # NLC=2 not working + - NumLoadsCoalescedB: [1] # NLC=2 not working + - ScheduleIterAlg: [3] + - AssertSummationElementMultiple: [32] + - StaggerU: [0] + - NumElementsPerBatchStore: [0] # minimum 2 for SCIU + - FractionalLoad: [0, 1, 2] + - BufferLoad: [True] # DirectToLds requires BufferLoad + - TransposeLDS: [1] + BenchmarkForkParameters: + JoinParameters: + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [1024, 1024, 1, 1024] + + - # MFMA 32x32, VW = 2 + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + - EdgeType: ["ShiftPtr"] + - AtomicAddC: [False] + ForkParameters: + - MatrixInstruction: + - [32, 32, 4, 2, 1, 2,2, 2,2] + - [32, 32, 4, 2, 2, 2,2, 2,2] + - [32, 32, 8, 1, 1, 2,2, 2,2] + - [32, 32, 8, 1, 1, 2,2, 1,4] + - [32, 32, 8, 1, 1, 2,2, 4,1] + - ThreadTile: + - [ 8, 32 ] + - WorkGroup: + - [ 16, 16, 1 ] + - SourceSwap: [True] + - PrefetchGlobalRead: [1] + - AssertFree0ElementMultiple: [2] + # - AssertFree1ElementMultiple: [2] + - WorkGroupMapping: [8] + - PrefetchLocalRead: [1, 2, 3] + - GlobalSplitU: [1] + - DepthU: [16, 32, 64] + - StoreVectorWidth: [2] + - VectorWidth: [2] + - GlobalReadVectorWidth: [4] # Error with GRVW < 4 + - LocalReadVectorWidth: [4] + - DirectToLds: [True] + - DirectToVgprA: [False] # DirectToVgpr only for dgemm right now + - DirectToVgprB: [False] # DirectToVgpr only for dgemm right now + - WaveSeparateGlobalReadA: [0, 1] + - WaveSeparateGlobalReadB: [0, 1] + - NumLoadsCoalescedA: [1] # NLC=2 not working + - NumLoadsCoalescedB: [1] # NLC=2 not working + - ScheduleIterAlg: [3] + - AssertSummationElementMultiple: [32] + - StaggerU: [0] + - NumElementsPerBatchStore: [0] # minimum 2 for SCIU + - FractionalLoad: [0, 1, 2] + - BufferLoad: [True] # DirectToLds requires BufferLoad + - TransposeLDS: [1] + BenchmarkForkParameters: + JoinParameters: + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [1024, 1024, 1, 1024]