Skip to content

Commit

Permalink
Allow distinct K0/K1 values for A/B block descriptor (#98)
Browse files Browse the repository at this point in the history
* add gitignore

* host tensor: allow generating sequentially increasing value in a given dimension

* gridwise gemm v3r1: allow distinct K0/K1 values for A/B block descriptor

- remove dangling header include
- modify example gemm_xdl accordingly
- infer KPack value from M/NPerXdl
- device conv2d fwd: update parameters accordingly for the underlying gridwise gemm v3r1
(API for conv2d fwd stays the same for now until we decide to expose individual K0s for activation and weight)

* add LDS data dump utility

* profiler: reflect API change for distinct K0/K1 for A/B matrices

* profiler: add conflict-free LDS write FP16 kernel instances

* fix accidental perf regression

* address feedback; cosmetic changes

* clang-format for new files

* format

Co-authored-by: Chao Liu <[email protected]>
  • Loading branch information
rosenrodt and Chao Liu authored Feb 28, 2022
1 parent e221d11 commit 6d4450e
Show file tree
Hide file tree
Showing 14 changed files with 431 additions and 268 deletions.
48 changes: 48 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch
*.ipch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

# vim tags
tags
.tags
.*.swp

# Editors
.vscode

# build-in-source directory
build*

# emacs temporary/backup files
.\#*
\#*\#
*~

# GDB temporary files
.gdb_history
109 changes: 54 additions & 55 deletions composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t K1>
index_t KPack>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static constexpr auto I0 = Number<0>{};
Expand All @@ -29,10 +29,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1

static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);

static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};

static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
Expand Down Expand Up @@ -66,7 +71,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1

const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0);
return make_tuple(0, waveId_m, xdlops_a_idx[I1], Number<KPack>{} * xdlops_a_idx[I0]);
}

__device__ static auto CalculateBThreadOriginDataIndex()
Expand All @@ -77,7 +82,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1

const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0);
return make_tuple(0, waveId_n, xdlops_b_idx[I1], Number<KPack>{} * xdlops_b_idx[I0]);
}

template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
Expand Down Expand Up @@ -115,12 +120,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");

static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
"wrong! K0 dimension not consistent");

static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2),
"wrong! K1 dimension not consistent");

static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");

Expand Down Expand Up @@ -219,32 +218,32 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}

__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<K0>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}

__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
__host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}

static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();

template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
Expand All @@ -258,31 +257,31 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1

static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, I0, I0, I0),
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0),
a_thread_buf);

static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, I0, I0, I0),
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0),
b_thread_buf);

static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) {
vector_type<FloatAB, K1> a_thread_vec;
vector_type<FloatAB, K1> b_thread_vec;
static_for<0, KPerBlock, KPack * xdlops_gemm.K0PerXdlops>{}([&](auto k) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;

static_for<0, K1, 1>{}([&](auto i) {
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});

using mfma_input_type =
Expand All @@ -301,37 +300,37 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
}

private:
// A[K0, M0, M1, M2, K1]
// A[M0, M1, M2, KPerBlock]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerBlock>{}));

// B[K0, N0, N1, N2, K1]
// B[N0, N1, N2, KPerBlock]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerBlock>{}));

// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<K0, 1, 1, 1, K1>,
Sequence<0, 1, 2, 3, 4>,
4,
K1,
K1>;
Sequence<1, 1, 1, KPerBlock>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;

using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<K0, 1, 1, 1, K1>,
Sequence<0, 1, 2, 3, 4>,
4,
K1,
K1>;
Sequence<1, 1, 1, KPerBlock>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;

AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
Expand Down
Loading

0 comments on commit 6d4450e

Please sign in to comment.