Skip to content

Commit

Permalink
Split k f16 (#97)
Browse files Browse the repository at this point in the history
* init for splitk f16

* a working prototype

* debug

* perf debug

* update example

* instances for mk kn

* add instances for all layers

* clean

* clean

* add tuning

* format

* add mn_padding into irregular tile

* clean

Co-authored-by: Chao Liu <[email protected]>
  • Loading branch information
zjing14 and Chao Liu authored Feb 25, 2022
1 parent bdedf64 commit e221d11
Show file tree
Hide file tree
Showing 11 changed files with 1,713 additions and 30 deletions.

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion device_operation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
)
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
)

# device_gemm_bias_2d_instance
set(DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE
Expand Down
4 changes: 2 additions & 2 deletions device_operation/include/conv_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ std::size_t GetFlops(ck::index_t N,
std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()) *
std::multiplies<std::size_t>()) *
C *
std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>());
std::multiplies<std::size_t>());
}

/**
Expand Down
Loading

0 comments on commit e221d11

Please sign in to comment.