From 1d8e4ec2ced2da813947e89654f69f7bf6b5079e Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Thu, 2 Jan 2025 04:48:06 +0100 Subject: [PATCH] Jing's contribution: prototype of mixed precision gemm FP16/BF16xint4 GEMM (#1762) * add a prototype of int4 * clean * debug * clean * clean * move packed into dynamic_buffer * fixed coord reset * add fast pki4 to half conversion * fix * fixed reference and host_tensor * fixed tensor init * format * debug i4_to_f16_convert * format * fixed splitk * weight permute * add b tile permute * clean * weight permute with splitki * format * improve weight layout * add and_or_b32 * fixed splitk crush * add permute switch as a template * recover v3r1 * clean * failure with intrawave v2 * fixed * fixed * add ckProfiler * add bfp16 support * add bf16 example * fixed int4 to bhalf_t conversion * format * fixed int4 to bf16 conversion * clean * add instances for mem * clean * fixed host tensor size * fixed * debug * fixed * add pk_i4_t as a struct * fix * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * revert * Update example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * fixed comments * revert * clean * revert * revert * fixed * Update CMakeLists.txt * Update script/cmake-ck-dev.sh Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Update CMakeLists.txt Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * fixed * fixed * fixed * revert * revert * add comments * format * fixed assert * fixed * Fix I4 define in ckProfiler * Fixed example_gemm_xdl_bf16_pk_i4_v3 test failed issue --------- Co-authored-by: Jing Zhang Co-authored-by: zjing14 Co-authored-by: mtgu0705 --- CMakeLists.txt | 2 +- cmake/EnableCompilerWarnings.cmake | 2 +- example/01_gemm/CMakeLists.txt | 2 + example/01_gemm/common.hpp | 82 +++++ example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp | 253 +++++++++++++++ example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp | 16 +- example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp | 303 ++++++++++++++++++ example/01_gemm/gemm_xdl_fp16_v3.cpp | 20 +- example/01_gemm/run_gemm_example.inc | 82 ----- .../01_gemm/run_gemm_example_streamk_v2.inc | 82 ----- example/01_gemm/run_gemm_example_v2.inc | 82 ----- include/ck/library/utility/host_tensor.hpp | 65 +++- .../library/utility/host_tensor_generator.hpp | 30 ++ include/ck/tensor/static_tensor.hpp | 4 +- .../gpu/device/device_gemm_v2.hpp | 4 + .../impl/device_gemm_xdl_cshuffle_v3.hpp | 13 +- .../element/unary_element_wise_operation.hpp | 189 +++++++++++ .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 104 ++++-- .../threadwise_tensor_slice_transfer.hpp | 50 ++- .../threadwise_tensor_slice_transfer_v3r1.hpp | 74 +++-- include/ck/utility/amd_buffer_addressing.hpp | 3 +- include/ck/utility/amd_inline_asm.hpp | 21 ++ include/ck/utility/data_type.hpp | 35 ++ include/ck/utility/dynamic_buffer.hpp | 6 +- include/ck/utility/static_buffer.hpp | 6 +- .../cpu/reference_gemm.hpp | 22 ++ .../device_operation_instance_factory.hpp | 1 + .../gpu/gemm_universal.hpp | 33 ++ .../gpu/gemm_universal/CMakeLists.txt | 3 + ...mm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp | 87 +++++ ..._bf16_mk_nk_mn_mem_v2_default_instance.cpp | 24 ++ ...gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp | 86 +++++ ...4_f16_mk_nk_mn_mem_v2_default_instance.cpp | 24 ++ .../profiler/profile_gemm_universal_impl.hpp | 103 +++++- profiler/src/CMakeLists.txt | 1 - profiler/src/profile_gemm_universal.cpp | 17 +- script/cmake-ck-dev.sh | 2 +- 37 files changed, 1583 insertions(+), 350 deletions(-) create mode 100644 example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp create mode 100644 example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index be4efd3dfd..6d4176735f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -585,7 +585,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS) ) add_subdirectory(example) if(BUILD_TESTING) - add_subdirectory(test) + add_subdirectory(test) endif() endif() diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 93fd306e98..fb2b38d688 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,7 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 957acce165..df7be0466e 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -29,6 +29,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3) add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp) +add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp) +add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3) add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index a3a62d4cfa..9664c50b6e 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -287,3 +287,85 @@ bool parse_cmd_args(int argc, return true; } + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} diff --git a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp new file mode 100644 index 0000000000..7b491173a6 --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::pk_i4_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t KPerBlock = 128; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 128, + 16, 64, + KPerBlock, 8, 32, + 16, 16, + 1, 2, + S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp index 2e27fc66f9..b0e36b394b 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" -using ADataType = ck::f8_t; -using BDataType = ck::half_t; +using ADataType = ck::half_t; +using BDataType = ck::f8_t; using AccDataType = float; using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; @@ -29,15 +29,15 @@ using DeviceGemmV2Instance = AElementOp, BElementOp, CElementOp, GemmDefault, 64, 16, 16, - 64, 16, 8, + 256, 8, 16, 16, 16, 1, 1, - S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 16, 16, 0, - S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, - ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v1>; + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_v3.cpp b/example/01_gemm/gemm_xdl_fp16_v3.cpp index ad370f570e..4a969246cd 100644 --- a/example/01_gemm/gemm_xdl_fp16_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_v3.cpp @@ -12,7 +12,7 @@ using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; using ALayout = Row; -using BLayout = Row; +using BLayout = Col; using CLayout = Row; using AElementOp = PassThrough; @@ -27,17 +27,17 @@ using DeviceGemmV2Instance = ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, PassThrough, PassThrough, PassThrough, GemmDefault, - 256, - 224, 256, - 64, 8, 2, + 64, + 16, 16, + 256, 8, 8, 16, 16, - 7, 8, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 1, 1, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 8, 2, 0, - 1, 2, S<1, 32, 1, 8>, 8, - ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 4>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 3ee6e26856..4371af6244 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -5,88 +5,6 @@ #include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc index 04243b8291..9ee380d247 100755 --- a/example/01_gemm/run_gemm_example_streamk_v2.inc +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 5b6969f1d9..2b60fa5d28 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 18e1db462a..ef5738be08 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -266,18 +266,18 @@ struct Tensor using Data = std::vector; template - Tensor(std::initializer_list lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) + Tensor(std::initializer_list lens) : mDesc(lens), mData(GetElementSpaceSize()) { } template Tensor(std::initializer_list lens, std::initializer_list strides) - : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize()) + : mDesc(lens, strides), mData(GetElementSpaceSize()) { } template - Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) + Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize()) { } @@ -287,7 +287,7 @@ struct Tensor { } - Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {} + Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {} template Tensor CopyAsType() const @@ -322,7 +322,17 @@ struct Tensor std::size_t GetElementSize() const { return mDesc.GetElementSize(); } - std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); } + std::size_t GetElementSpaceSize() const + { + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return (mDesc.GetElementSpaceSize() + 1) / 2; + } + else + { + return mDesc.GetElementSpaceSize(); + } + } std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } @@ -469,29 +479,64 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - return mDesc.GetOffsetFromMultiIndex(is...); + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mDesc.GetOffsetFromMultiIndex(is...) / 2; + } + else + { + return mDesc.GetOffsetFromMultiIndex(is...); + } } template T& operator()(Is... is) { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } } template const T& operator()(Is... is) const { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } } T& operator()(std::vector idx) { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } } const T& operator()(std::vector idx) const { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } } typename Data::iterator begin() { return mData.begin(); } diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index ab9f01b53c..6a90523c33 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -81,6 +81,20 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + int8_t value = 1; + + template + ck::pk_i4_t operator()(Is...) + { + int t = value + 8; + ck::pk_i4_t r = ((t << 4) + t) & 0xff; + return r; + } +}; + template struct GeneratorTensor_2 { @@ -121,6 +135,22 @@ struct GeneratorTensor_2 } }; +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::pk_i4_t operator()(Is...) + { + int hi = std::rand() % (max_value - min_value) + min_value + 8; + int lo = std::rand() % (max_value - min_value) + min_value + 8; + ck::pk_i4_t r = ((hi << 4) + lo) & 0xff; + return r; + } +}; + #if defined CK_ENABLE_FP8 template <> struct GeneratorTensor_2 diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index d719ef9760..ef2bedd65c 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer // Idx is for S, not X. Idx should be aligned with X template ::value && + typename enable_if<(has_same_scalar_type::value || !is_native_type()) && is_known_at_compile_time::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr X GetAsType(Idx) const @@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer // Idx is for S, not X. Idx should be aligned with X template ::value && + typename enable_if<(has_same_scalar_type::value || !is_native_type()) && is_known_at_compile_time::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr void SetAsType(Idx, X x) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp index b2db35b159..43909f77d3 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp @@ -36,6 +36,10 @@ struct DeviceGemmV2 : public BaseOperator CElementwiseOperation c_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual bool GetPermuteA() = 0; + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; }; template + typename ComputeTypeB = ComputeTypeA, + bool PermuteA = false, + bool PermuteB = false> struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2; + ComputeTypeB, + PermuteA, + PermuteB>; using Argument = typename GridwiseGemm::Argument; @@ -633,6 +637,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2(p_arg)); } + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + static auto MakeArgument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c, diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 39b81ca573..86a5af41bc 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -7,12 +7,177 @@ #include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/utility/amd_inline_asm.hpp" #include namespace ck { + +// Fast int4x4 to half8_t data type conversion based on paper +// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production] +// (https://arxiv.org/abs/2211.10017) and implementation: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__host__ __device__ inline half4_t pki4_to_half4(int q) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + // Extract the two int4 at low bit and create two fp16 number. + int lo = amd_assembly_and_or_b32(q, LO, EX); + // Extract the two int4 at hight bit and create two fp16 number. + int hi = amd_assembly_and_or_b32(q, HI, EX); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + vector_type res; + + // for two fp16 from lowbit, subtract 1032 to get correct fp16 value + res.template AsType()(Number<0>{}) = + amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); + + // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value + res.template AsType()(Number<1>{}) = amd_assembly_pk_fma_f16( + bit_cast(hi), bit_cast(MUL), bit_cast(ADD)); + + return res.template AsType()[Number<0>{}]; +} + +__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q) +{ +#if 1 + uint8_t x_u8 = ck::bit_cast(q); + uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4); + + const int EX = 0x64006400; + const int SUB = 0xE408E408; //-8 + + int lo = i4s | EX; + + return amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); +#else + uint8_t x_u8 = ck::bit_cast(q); + + vector_type res; + + half_t x_h = (x_u8 & 0x0f) - 8; + half_t x_l = ((x_u8 & 0xf0) >> 4) - 8; + + res.template AsType()(Number<0>{}) = x_l; + res.template AsType()(Number<1>{}) = x_h; + + return res.template AsType()[Number<0>{}]; +#endif +} + +__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q) +{ + uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); + + static constexpr uint32_t fp32_base = 0x4B000000; + + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388616.f; + fp32_intermediates[1] -= 8388616.f; + fp32_intermediates[2] -= 8388616.f; + fp32_intermediates[3] -= 8388616.f; + + vector_type res; + res.template AsType()(Number<0>{}) = bit_cast( + __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632)); + res.template AsType()(Number<1>{}) = bit_cast( + __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632)); + + return res.template AsType()[Number<0>{}]; +} + +__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q) +{ + uint8_t x_u8 = ck::bit_cast(q); + + float x_h = ((x_u8 & 0x0f) >> 0) - 8.f; + float x_l = ((x_u8 & 0xf0) >> 4) - 8.f; + + vector_type res; + + res.template AsType()(Number<0>{}) = type_convert(x_l); + res.template AsType()(Number<1>{}) = type_convert(x_h); + + return res.template AsType()[Number<0>{}]; +} + namespace tensor_operation { namespace element_wise { +struct PassThroughPack8 +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const + { +#if 1 + vector_type result; + + result.template AsType()(Number<0>{}) = pki4_to_half4(bit_cast(x)); + result.template AsType()(Number<1>{}) = pki4_to_half4(bit_cast(x) >> 8); + + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; + + dst.template AsType()(Number<0>{}) = + pki4_to_half2(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + pki4_to_half2(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + pki4_to_half2(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + pki4_to_half2(src.template AsType()[Number<3>{}]); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + + __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const + { +#if 1 + vector_type result; + + result.template AsType()(Number<0>{}) = pki4_to_bhalf4(bit_cast(x)); + result.template AsType()(Number<1>{}) = pki4_to_bhalf4(bit_cast(x) >> 16); + + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; + + dst.template AsType()(Number<0>{}) = + pki4_to_bhalf2(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + pki4_to_bhalf2(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + pki4_to_bhalf2(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + pki4_to_bhalf2(src.template AsType()[Number<3>{}]); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + + constexpr const static bool is_pack8_invocable = true; +}; + #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wnon-virtual-dtor" struct UnaryOpBase @@ -49,6 +214,24 @@ struct PassThroughPack2 auto t = type_convert(x); y = type_convert(t); } + + __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const + { +#if 1 + uint8_t x_u8 = ck::bit_cast(x); + uint8_t x_l = (x_u8 & 0x0f) >> 0; + uint8_t x_h = (x_u8 & 0xf0) >> 4; + + auto l_f16 = ck::type_convert(x_l); + auto h_f16 = ck::type_convert(x_h); + + y = {l_f16, h_f16}; +#else + uint32_t t = ck::bit_cast(x); + y = ck::bit_cast(t); +#endif + } + constexpr const static bool is_pack2_invocable = true; }; @@ -76,6 +259,12 @@ struct PassThrough final : public UnaryOpBase template __host__ __device__ void operator()(Y& y, const X& x) const; + template <> + __host__ __device__ void operator()(pk_i4_t& y, const pk_i4_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(float& y, const double& x) const { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 36797a906a..a43f0f880a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -127,7 +127,9 @@ template + typename ComputeTypeB = ComputeTypeA, + bool PermuteA = false, + bool PermuteB = false> struct GridwiseGemm_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; @@ -151,6 +153,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 using ThisThreadBlock = ThisThreadBlock; + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -319,6 +335,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 using GemmSpecialization = tensor_operation::device::GemmSpecialization; + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + if constexpr(GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding) { @@ -373,15 +393,39 @@ struct GridwiseGemm_xdl_cshuffle_v3 } else { - // not pad N or K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } } } @@ -572,7 +616,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead; + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; } else if constexpr(is_same_v) { @@ -585,7 +629,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 } else if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead; + if constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } } if(blockIdx.z < static_cast(karg.KBatch - 1)) @@ -625,9 +677,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 // in some cases. else if constexpr(is_same::value) { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( make_tuple( AK0Number * Number{}, Number{}, AK1Number), @@ -761,10 +812,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 else if constexpr(is_same::value) { // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(BDataType); - ; + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize; constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( make_tuple( BK0Number * Number{}, Number{}, BK1Number), @@ -946,8 +995,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)), + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), c_block_size * sizeof(CShuffleDataType)); } @@ -1312,8 +1361,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); @@ -1706,16 +1756,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf_ping = make_dynamic_buffer( - static_cast(p_shared_0) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType)), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); auto a_block_buf_pong = make_dynamic_buffer( static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf_pong = make_dynamic_buffer( - static_cast(p_shared_1) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType)), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index d7a6a36244..7589002003 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx) : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) { @@ -1015,6 +1022,11 @@ struct ThreadwiseTensorSliceTransfer_v4 static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong! Not divisible"); + + if constexpr(is_same_v, pk_i4_t>) + { + static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); + } } template src_tmp_vector; + vector_type_maker_t src_tmp_vector; using src_vector_t = typename decltype(src_tmp_vector)::type; @@ -1120,7 +1132,8 @@ struct ThreadwiseTensorSliceTransfer_v4 if constexpr(SrcBuffer::IsDynamicBuffer()) { src_tmp_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + src_buf.template Get(src_data_coord.GetOffset() / PackedSize, + is_src_valid); } else if constexpr(SrcBuffer::IsStaticBuffer()) { @@ -1133,9 +1146,36 @@ struct ThreadwiseTensorSliceTransfer_v4 }); } - if constexpr(is_same, f8_t>::value && - is_same, half_t>::value && - SrcScalarPerVector % 2 == 0) + if constexpr(is_same, pk_i4_t>::value) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 8; + + static_assert(SrcScalarPerVector % pack_size == 0, ""); + + using src_v_t = typename vector_type_maker_t::type; + using dst_v_t = typename vector_type_maker_t::type; + + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack8{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) { // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // DstData) diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 96ea04c8fa..8cbe6bd2c5 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -31,8 +31,8 @@ template {}; + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr auto SrcScalarPerVector = Number{}; + static constexpr auto DstScalarPerVector = Number{}; + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( const SrcDesc& src_desc, const Index& src_slice_origin, @@ -67,6 +77,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_element_op_(src_element_op), dst_element_op_(dst_element_op) { + if constexpr(is_same_v, pk_i4_t>) + { + static_assert(is_same_v, remove_cvref_t>, + "SrcData != DstData"); + + static_assert( + SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, + "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); + + static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose"); + } } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -95,11 +116,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0, + static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); constexpr auto src_dim_access_order = SrcDimAccessOrder{}; @@ -180,9 +201,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using src_vector_type = vector_type_maker_t; using src_vector_t = typename src_vector_type::type; - auto src_vector_container = - src_vector_type{src_buf.template Get(src_coord_.GetOffset(), true)}; - using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; dst_vector_type op_r_v; @@ -193,17 +211,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1 if constexpr(decltype(src_element_op_)::is_pack8_invocable) return math::min(8, SrcScalarPerVector); } - if constexpr(is_detected::value) + else if constexpr(is_detected::value) { if constexpr(decltype(src_element_op_)::is_pack4_invocable) return math::min(4, SrcScalarPerVector); } - if constexpr(is_detected::value) + else if constexpr(is_detected::value) { if constexpr(decltype(src_element_op_)::is_pack2_invocable) return math::min(2, SrcScalarPerVector); } - return 1; + else + { + return 1; + } }; constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); @@ -211,6 +234,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using src_elem_op_vec_t = typename vector_type::type; using dst_elem_op_vec_t = typename vector_type::type; + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset() / PackedSize, true)}; + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) { // apply the src elementwise op and convert to DstData under the hood if needed src_element_op_(op_r_v.template AsType()(idx), @@ -276,10 +302,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); #else - // OOB Check constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -350,6 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 (is_same>::value && SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) { + static_assert(!is_same_v, pk_i4_t>, + "in-register transpose is not supported for pk_i4_t"); // each transpose does // DstScalarPerVector # of src vectors in src_thread_scratch_ // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ @@ -410,7 +437,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 } else { - static_ford{}([&](auto idx) { + constexpr auto packed_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access; + + static_ford{}([&](auto idx) { dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); } @@ -438,7 +470,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // src scalar per access on each dim // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -526,13 +558,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // apply DstElementwiseOperation dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); - - dst_vector_container.template AsType()(i) = dst_v; }); // copy data from dst_vector_container to dst_buf dst_buf.template Set( - dst_coord_.GetOffset(), + dst_coord_.GetOffset() / PackedSize, is_dst_valid, dst_vector_container.template AsType()[I0]); @@ -586,7 +616,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -644,7 +674,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -730,7 +760,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ static constexpr auto GetSrcThreadScratchDescriptor() { constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -779,7 +809,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() { constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -790,7 +820,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 { // 1st stage of transforms constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 5367c3d720..ad13c44311 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -429,7 +429,8 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); using r_t = typename vector_type::type; diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 5dc67a5ade..6761c08f2b 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -11,6 +11,27 @@ namespace ck { +inline __device__ int amd_assembly_and_or_b32(int a, int b, int d) +{ + int c; + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d)); + return c; +} + +inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c) +{ + half2_t d; + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + return d; +} + +inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b) +{ + half2_t c; + asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +} + // c0 += inner_product(a, b0) // c1 += inner_product(a, b1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index a7dc071bc2..86bc3c394e 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -12,6 +12,15 @@ using bhalf_t = ushort; using half_t = _Float16; using int4_t = _BitInt(4); +// custom data type - pack int4 data +struct pk_i4_t +{ + using type = int8_t; + type data; + __host__ __device__ constexpr pk_i4_t() : data{type{}} {} + __host__ __device__ constexpr pk_i4_t(type init) : data{init} {} +}; + inline constexpr auto next_pow2(uint32_t x) { // Precondition: x > 1. @@ -165,6 +174,13 @@ struct scalar_type }; #endif +template <> +struct scalar_type +{ + using type = pk_i4_t; + static constexpr index_t vector_size = 1; +}; + template <> struct scalar_type { @@ -1044,6 +1060,12 @@ struct nnvb_data_t_selector using type = bf8_ocp_t::data_type; }; +template <> +struct nnvb_data_t_selector +{ + using type = pk_i4_t::type; +}; + template struct non_native_vector_base< T, @@ -1163,6 +1185,14 @@ struct scalar_type> static constexpr index_t vector_size = N; }; +template +struct scalar_type> +{ + using type = typename non_native_vector_base::data_t; + + static constexpr index_t vector_size = N; +}; + // non-native vector_type implementation template struct vector_type()>> @@ -1871,6 +1901,11 @@ using uint8x16_t = typename vector_type::type; using uint8x32_t = typename vector_type::type; using uint8x64_t = typename vector_type::type; +// pack int4 +using pk_i4x2_t = typename vector_type::type; +using pk_i4x4_t = typename vector_type::type; +using pk_i4x8_t = typename vector_type::type; + template struct NumericLimits { diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 0dcc514a2f..639aa1efe4 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -54,7 +54,8 @@ struct DynamicBuffer template >::type, - typename scalar_type>::type>::value, + typename scalar_type>::type>::value || + !is_native_type(), bool>::type = false> __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const { @@ -195,7 +196,8 @@ struct DynamicBuffer template >::type, - typename scalar_type>::type>::value, + typename scalar_type>::type>::value || + !is_native_type(), bool>::type = false> __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) { diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index 835f565730..602e76abdb 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector // i is offset of S, not X. i should be aligned to X template ::value, bool>::type = false> + typename enable_if::value || !is_native_type(), + bool>::type = false> __host__ __device__ constexpr auto GetAsType(Number i) const { constexpr auto s_per_x = Number>::vector_size>{}; @@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector // i is offset of S, not X. i should be aligned to X template ::value, bool>::type = false> + typename enable_if::value || !is_native_type(), + bool>::type = false> __host__ __device__ constexpr void SetAsType(Number i, X x) { constexpr auto s_per_x = Number>::vector_size>{}; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 1ae11fe9db..8dd5d086b3 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -74,6 +74,17 @@ struct ReferenceGemm : public device::BaseOperator { ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); } + else if constexpr(is_same_v) + { + uint8_t i4x2 = arg.a_m_k_(m, k).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2 >> 0) & 0xf; + else + i4 = (i4x2 >> 4) & 0xf; + i4 = i4 - 8; + v_a = type_convert(i4); + } else { arg.a_element_op_(v_a, arg.a_m_k_(m, k)); @@ -84,6 +95,17 @@ struct ReferenceGemm : public device::BaseOperator { ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); } + else if constexpr(is_same_v) + { + uint8_t i4x2 = arg.b_k_n_(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2 >> 0) & 0xf; + else + i4 = (i4x2 >> 4) & 0xf; + i4 = i4 - 8; + v_b = type_convert(i4); + } else { arg.b_element_op_(v_b, arg.b_k_n_(k, n)); diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 4358953a58..4a44c425aa 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -22,6 +22,7 @@ using I8 = int8_t; using I32 = int32_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; +using I4 = ck::pk_i4_t; using Empty_Tuple = ck::Tuple<>; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp index 52046a1070..4218c51ca3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp @@ -166,11 +166,22 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances std::vector>>& instances); + void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances( std::vector>>& instances); +void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( std::vector>>& @@ -810,6 +821,28 @@ struct DeviceOperationInstanceFactory< } } #endif + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); + } + } + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances( + op_ptrs); + } + } + return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt index 188c9f68e3..ade65eacf3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt @@ -97,6 +97,9 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..8d109d1346 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I4 = pk_i4_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +#if 0 +template +using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; +#endif + +template +using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| ACompType| BCompType| APermute| BPermute| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| | | | | + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| | | | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..b060a92eb6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..680788d668 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I4 = pk_i4_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +#if 0 +template +using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; +#endif + +template +using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| ACompType| BCompType| APermute| BPermute| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| | | | | + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| | | | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..a884a3ec56 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index 30f0da212d..ed7e86ded8 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -65,11 +65,13 @@ bool profile_gemm_universal_impl(int do_verification, Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); - int rotating_count = std::max( + std::size_t total_gemm_needed = + a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( 1, std::min(n_iter, static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); @@ -86,9 +88,13 @@ bool profile_gemm_universal_impl(int do_verification, a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); break; - default: + case 2: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); } using AElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -100,11 +106,10 @@ bool profile_gemm_universal_impl(int do_verification, const auto c_element_op = CElementOp{}; DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); using DeviceOp = ck::tensor_operation::device::DeviceGemmV2GetKPerBlock(); + + if(op_ptr->GetPermuteB()) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + + if(is_same_v && is_same_v) + { + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i); + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + } + } + else + { + b_k_n_permute = b_k_n; + } + + b_device_buf.ToDevice(b_k_n_permute.mData.data()); + std::vector kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; if(KBatch > 0) @@ -240,7 +323,15 @@ bool profile_gemm_universal_impl(int do_verification, std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + std::size_t num_btype = sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / BPackedSize + sizeof(CDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 35e91f8172..a0978eb6bf 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -177,5 +177,4 @@ if(DL_KERNELS) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) endif() - rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/profiler/src/profile_gemm_universal.cpp b/profiler/src/profile_gemm_universal.cpp index 990cbd292e..a22d983da5 100644 --- a/profiler/src/profile_gemm_universal.cpp +++ b/profiler/src/profile_gemm_universal.cpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +#include +#include #include #include -#include -#include #include "profiler/profile_gemm_universal_impl.hpp" #include "profiler_operation_registry.hpp" @@ -27,6 +27,8 @@ enum struct GemmDataType F16_F8_F16, // 5 F16_F16_F16_F8, // 6 F8_F8_BF16, // 7 + F16_I4_F16, // 8 + BF16_I4_BF16, // 9 }; #define OP_NAME "gemm_universal" @@ -39,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[]) printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " "f16->f8; 7: f8->bf16, " - "comp f8)\n"); + "comp f8; 8: f16@i4; 9: bf16@i4\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); @@ -103,6 +105,7 @@ int profile_gemm_universal(int argc, char* argv[]) using BF16 = ck::bhalf_t; #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) using F8 = ck::f8_t; + using I4 = ck::pk_i4_t; #endif using Row = ck::tensor_layout::gemm::RowMajor; @@ -207,6 +210,14 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{}); } + else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::BF16_I4_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(BF16{}, I4{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{}); + } #endif else { diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 4097ca98f6..f7177a7abb 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -17,7 +17,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \