Skip to content

Commit

Permalink
DirectToLds for hgemm configurations (ROCm#1536)
Browse files Browse the repository at this point in the history
Update checks for DirectToLds with hgemm and add test cases
  • Loading branch information
AlexBrownAMD authored Jun 13, 2022
1 parent be83117 commit c7d39b6
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 10 deletions.
20 changes: 10 additions & 10 deletions Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2435,9 +2435,16 @@ def isDirectToLdsDoable(state, tc):
#TN
# use for all precisions with TransposeLDS=1

if state["ProblemType"]["DataType"].isHalf() and state["AssertSummationElementMultiple"] % (2 * state["GlobalLoadVectorWidth%c"%tc]) != 0:
reject(state, "can't use DirectToLds for FP16 with AssertSummationElementMultiple %u" % state["AssertSummationElementMultiple"])
return False
if state["ProblemType"]["DataType"].isHalf():
if state["AssertSummationElementMultiple"] % (2 * state["GlobalLoadVectorWidth%c"%tc]) != 0:
reject(state, "can't use DirectToLds for FP16 with AssertSummationElementMultiple %u" % state["AssertSummationElementMultiple"])
return False
if state["ProblemType"]["TransposeA"] != True or state["ProblemType"]["TransposeB"] != False:
reject(state, "DirectToLds for FP16 currently only working for TN")
return False
if state["GlobalReadVectorWidth"] < 4:
reject(state, "GlobalReadVectorWidth must be 4 for DirectToLds HGEMM")
return False

if state["ProblemType"]["DataType"].isBFloat16() and state["AssertSummationElementMultiple"] % (2 * state["GlobalLoadVectorWidth%c"%tc]) != 0:
reject(state, "can't use DirectToLds for BF16 with AssertSummationElementMultiple %u" % state["AssertSummationElementMultiple"])
Expand All @@ -2447,13 +2454,6 @@ def isDirectToLdsDoable(state, tc):
reject(state, "can't use DirectToLds for NumThreads % WavefrontSize != 0")
return False

# GLVW*BPe only for precision(s) < 4 (bpe)
#if (state["ProblemType"]["TLU%c"%tc] == True and numBytes < 4):
if (numBytes < 4):
if state["GlobalLoadVectorWidth%c"%tc] * numBytes != 4:
reject(state, "can't use DirectToLds for bpe < 4 and GlobalLoadVectorWidth * numBytes != 4"%tc)
return False

if state["ProblemType"]["TLU%c"%tc] == state["UnrollMajorLDS%c" % tc]:
reject(state, "can't use DirectToLds for TLU%c == UnrollMajorLDS%c"%(tc, tc))
return False
Expand Down
129 changes: 129 additions & 0 deletions Tensile/Tests/extended/direct_to_lds/dtl_hgemm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
TestParameters:
marks: [skip-gfx900, skip-gfx906, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030] # not supported by arch

GlobalParameters:
NumElementsToValidate: -1
BoundsCheck: True
KernelTime: True
# PrintSolutionRejectionReason: True

BenchmarkProblems:
########################################
# TN
########################################
- # hgemm TN
- # ProblemType
OperationType: GEMM
DataType: h
ComputeDataType: s
HighPrecisionAccumulate: True
TransposeA: True
TransposeB: False
UseBeta: True
Batched: True

- # MFMA 16x16 and 4x4, VW = 4
InitialSolutionParameters:
BenchmarkCommonParameters:
- KernelLanguage: ["Assembly"]
- EdgeType: ["ShiftPtr"]
- AtomicAddC: [False]
ForkParameters:
- MatrixInstruction:
- [16, 16, 16, 1, 1, 4,4, 2,2]
- [16, 16, 16, 1, 1, 4,4, 1,4]
- [16, 16, 16, 1, 1, 4,4, 4,1]
- [16, 16, 4, 4, 1, 4,4, 2,2]
- [16, 16, 4, 4, 2, 4,4, 2,2]
- [16, 16, 4, 4, 4, 4,4, 2,2]
# - [4, 4, 4, 16, 1, 4,4, 2,2] # Error
- [4, 4, 4, 16, 2, 4,4, 2,2]
- [4, 4, 4, 16, 4, 4,4, 2,2]
- [4, 4, 4, 16, 8, 4,4, 2,2]
# - [4, 4, 4, 16, 16, 4,4, 2,2] # Error
- ThreadTile:
- [ 8, 32 ]
- WorkGroup:
- [ 16, 16, 1 ]
- SourceSwap: [True]
- PrefetchGlobalRead: [1]
- AssertFree0ElementMultiple: [2]
# - AssertFree1ElementMultiple: [2]
- WorkGroupMapping: [8]
- PrefetchLocalRead: [1, 2, 3]
- GlobalSplitU: [1]
- DepthU: [16, 32, 64]
- StoreVectorWidth: [4]
- VectorWidth: [4]
- GlobalReadVectorWidth: [4] # GRVW < 4 causes errors with DirectToLds HGEMM
- LocalReadVectorWidth: [4]
- DirectToLds: [True]
- DirectToVgprA: [False] # DirectToVgpr only for dgemm right now
- DirectToVgprB: [False] # DirectToVgpr only for dgemm right now
- WaveSeparateGlobalReadA: [0, 1]
- WaveSeparateGlobalReadB: [0, 1]
- NumLoadsCoalescedA: [1] # NumLoadsCoalescedA=2 is not working
- NumLoadsCoalescedB: [1] # NumLoadsCoalescedB=2 is not working
- ScheduleIterAlg: [3]
- AssertSummationElementMultiple: [32]
- StaggerU: [0]
- NumElementsPerBatchStore: [0] # minimum 2 for SCIU
- FractionalLoad: [0, 1, 2]
- BufferLoad: [True] # DirectToLds requires BufferLoad
- TransposeLDS: [1]
BenchmarkForkParameters:
JoinParameters:
BenchmarkJoinParameters:
BenchmarkFinalParameters:
- ProblemSizes:
- Exact: [1024, 1024, 1, 1024]

- # MFMA 32x32, VW = 2
InitialSolutionParameters:
BenchmarkCommonParameters:
- KernelLanguage: ["Assembly"]
- EdgeType: ["ShiftPtr"]
- AtomicAddC: [False]
ForkParameters:
- MatrixInstruction:
- [32, 32, 4, 2, 1, 2,2, 2,2]
- [32, 32, 4, 2, 2, 2,2, 2,2]
- [32, 32, 8, 1, 1, 2,2, 2,2]
- [32, 32, 8, 1, 1, 2,2, 1,4]
- [32, 32, 8, 1, 1, 2,2, 4,1]
- ThreadTile:
- [ 8, 32 ]
- WorkGroup:
- [ 16, 16, 1 ]
- SourceSwap: [True]
- PrefetchGlobalRead: [1]
- AssertFree0ElementMultiple: [2]
# - AssertFree1ElementMultiple: [2]
- WorkGroupMapping: [8]
- PrefetchLocalRead: [1, 2, 3]
- GlobalSplitU: [1]
- DepthU: [16, 32, 64]
- StoreVectorWidth: [2]
- VectorWidth: [2]
- GlobalReadVectorWidth: [4] # GRVW < 4 causes errors with DirectToLds HGEMM
- LocalReadVectorWidth: [4]
- DirectToLds: [True]
- DirectToVgprA: [False] # DirectToVgpr only for dgemm right now
- DirectToVgprB: [False] # DirectToVgpr only for dgemm right now
- WaveSeparateGlobalReadA: [0, 1]
- WaveSeparateGlobalReadB: [0, 1]
- NumLoadsCoalescedA: [1] # NumLoadsCoalescedA=2 is not working
- NumLoadsCoalescedB: [1] # NumLoadsCoalescedB=2 is not working
- ScheduleIterAlg: [3]
- AssertSummationElementMultiple: [32]
- StaggerU: [0]
- NumElementsPerBatchStore: [0] # minimum 2 for SCIU
- FractionalLoad: [0, 1, 2]
- BufferLoad: [True] # DirectToLds requires BufferLoad
- TransposeLDS: [1]
BenchmarkForkParameters:
JoinParameters:
BenchmarkJoinParameters:
BenchmarkFinalParameters:
- ProblemSizes:
- Exact: [1024, 1024, 1, 1024]

0 comments on commit c7d39b6

Please sign in to comment.