From 29e58d5b28a7f8490ced9b25c17519d110f7bba7 Mon Sep 17 00:00:00 2001
From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Date: Tue, 21 May 2024 20:37:26 +0000
Subject: [PATCH] Add a library that generates CK instances for use by the
 PyTorch 2 inductor CK backend

Also bundle the CK library and include files with the pip package.

The package is pip-installable with
`pip install git+https://github.com/tenpercent/composable_kernel@enable-pip`
(substitute the repo path and branch if necessary).

Testing:

`myenv/bin/python3 -m ck4inductor.universal_gemm.gen_instances`
(prints a list of instances)

`tree myenv/lib/python3.12/site-packages/ck4inductor`
(observe the list of sources installed alongside the package)

See also the usage sketch appended after the diff.
---
 pyproject.toml                                |  36 ++
 python/ck4inductor/__init__.py                |   0
 .../universal_gemm/gen_instances.py           | 570 ++++++++++++++++++
 python/ck4inductor/universal_gemm/op.py       |  95 +++
 python/ck4inductor/util.py                    |   7 +
 5 files changed, 708 insertions(+)
 create mode 100644 pyproject.toml
 create mode 100644 python/ck4inductor/__init__.py
 create mode 100644 python/ck4inductor/universal_gemm/gen_instances.py
 create mode 100644 python/ck4inductor/universal_gemm/op.py
 create mode 100644 python/ck4inductor/util.py

diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000..8e7e8607ba
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,36 @@
+[build-system]
+requires = ["setuptools", "setuptools-scm"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "rocm-composable-kernel"
+dynamic = ["version"]
+description = "Composable Kernel, performance-critical kernels for machine learning workloads"
+readme = "README.md"
+requires-python = ">=3.8"
+license = {file = "LICENSE"}
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+]
+dependencies = []
+
+[project.urls]
+"Homepage" = "https://github.com/rocm/composable_kernel"
+"Bug Tracker" = "https://github.com/rocm/composable_kernel/issues"
+
+[tool.setuptools]
+packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library"]
+
+[tool.setuptools.package-dir]
+ck4inductor = "python/ck4inductor"
+"ck4inductor.include" = "include"
+"ck4inductor.library" = "library"
+
+[tool.setuptools.package-data]
+"ck4inductor.include" = ["ck/**/*.hpp"]
+"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"]
+
+[tool.setuptools.dynamic]
+version = { attr = "setuptools_scm.get_version" }
diff --git a/python/ck4inductor/__init__.py b/python/ck4inductor/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/python/ck4inductor/universal_gemm/gen_instances.py b/python/ck4inductor/universal_gemm/gen_instances.py
new file mode 100644
index 0000000000..8b6d6b73b2
--- /dev/null
+++ b/python/ck4inductor/universal_gemm/gen_instances.py
@@ -0,0 +1,570 @@
+import logging
+import os
+import subprocess
+from dataclasses import fields, replace
+from functools import lru_cache, partial
+from typing import List
+
+from ..util import library_path
+
+from .op import CKGemmOperation
+
+log = logging.getLogger(__name__)
+
+
+def _ck_library_dir():
+    gemm_instances_path = os.path.join(
+        library_path(), "src", "tensor_operation_instance", "gpu", "gemm_universal"
+    )
+    if not os.path.exists(gemm_instances_path):
+        log.error("CK library path %s does not exist", gemm_instances_path)
+        return None
+    return gemm_instances_path
+
+
+def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
+    """
+    Parse the lines 
containing Universal Gemm template instances into `CKGemmOperation` instances + """ + + def maybe_int(s): + try: + return int(s) + except ValueError: + return s + + op_instances = [] + for line in str_instances: + s_template_args = line.split("DeviceGemm_Xdl_CShuffleV3")[-1].strip("<>, ") + template_args = [] + i_current = 0 + while i_current < len(s_template_args): + if s_template_args[i_current] == " ": + # skip whitespace + i_current += 1 + continue + elif s_template_args[i_current : i_current + 2] == "S<": + # parse template S + i_next = s_template_args.find(">", i_current) + template_args.append( + tuple(map(int, s_template_args[i_current + 2 : i_next].split(","))) + ) + i_current = i_next + 2 + else: + # all string attributes must be either type aliases or global constants in C++ + i_next = s_template_args.find(",", i_current) + template_args.append( + maybe_int( + s_template_args[i_current : i_next if i_next != -1 else None] + ) + ) + if i_next != -1: + i_current = i_next + 1 + if i_next == -1: + break + # pad with `None`s for the fields which are not defined in the instance + new_instance = CKGemmOperation( + *template_args, # type: ignore[arg-type] + *((None,) * (len(fields(CKGemmOperation)) - len(template_args))), + ) + # the last 2 template parameters are optional + # if they are absent, substitute them with default values from Universal Gemm C++ template declaration + if new_instance.a_compute_dtype is None: + new_instance.a_compute_dtype = new_instance.c_element_dtype + if new_instance.b_compute_dtype is None: + new_instance.b_compute_dtype = new_instance.c_element_dtype + + op_instances.append(new_instance) + return op_instances + + +def default_instances() -> List[CKGemmOperation]: + # fallback: known working op instance for problem size M=2240 K=256 N=2048 + # all string attributes must be either type aliases or global constants in C++ + + return [ + CKGemmOperation( + a_layout="Row", + b_layout="Row", + c_layout="Row", + a_element_dtype="F16", + b_element_dtype="F16", + c_element_dtype="F16", + a_compute_dtype="F16", + b_compute_dtype="F16", + acc_dtype="F32", + c_shuffle_dtype="F16", + a_elementwise_op="PassThrough", + b_elementwise_op="PassThrough", + c_elementwise_op="PassThrough", + gemm_specialization="GemmSpecialization::Default", + block_size=256, + m_per_block=224, + n_per_block=256, + k_per_block=64, + a_k1=8, + b_k1=2, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=7, + n_xdl_per_wave=8, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1), + a_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + a_block_transfer_src_access_order=(1, 0, 2), + a_block_transfer_src_vector_dim=2, + a_block_transfer_src_scalar_per_vector=8, + a_block_transfer_dst_scalar_per_vector_ak1=8, + a_block_lds_extra_m=0, # type: ignore[arg-type] + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1), + b_block_transfer_thread_cluster_arrange_order=(0, 2, 1), + b_block_transfer_src_access_order=(0, 2, 1), + b_block_transfer_src_vector_dim=1, + b_block_transfer_src_scalar_per_vector=8, + b_block_transfer_dst_scalar_per_vector_bk1=2, + b_block_lds_extra_n=0, # type: ignore[arg-type] + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ) + 
] + + +@lru_cache(None) +def gen_ops_library() -> List[CKGemmOperation]: + """ + Parse the Universal Gemm instances defined in the composable kernel library folder. + """ + ck_library_dir = _ck_library_dir() + if not ck_library_dir: + return [] + + grep_result = subprocess.run( + [ + "grep", + "-inR", + "DeviceGemm_Xdl_CShuffleV3", + _ck_library_dir(), + ], + capture_output=True, + text=True, + ) + + op_instances = parse_instances(grep_result.stdout.strip().split("\n")) + + log.debug("ck instances from library: %d", len(op_instances)) + + schedulers = [ + "BlockGemmPipelineScheduler::Intrawave", + "BlockGemmPipelineScheduler::Interwave", + ] + gemm_specs = [ + "GemmSpecialization::Default", + "GemmSpecialization::MPadding", + "GemmSpecialization::NPadding", + "GemmSpecialization::KPadding", + "GemmSpecialization::MNPadding", + "GemmSpecialization::MKPadding", + "GemmSpecialization::NKPadding", + "GemmSpecialization::MNKPadding", + ] + + # substitute templated args by looping through their domains + substitute_instances = [] + for instance in op_instances: + sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched" + sub_spec = instance.gemm_specialization == "GemmSpec" + schedulers_range = ( + schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler] + ) + spec_range = gemm_specs if sub_spec else [instance.gemm_specialization] + for scheduler in schedulers_range: + for spec in spec_range: + substitute_instances.append( + replace( + instance, + block_gemm_pipeline_scheduler=scheduler, + gemm_specialization=spec, + ) + ) + + return substitute_instances + + +@lru_cache(None) +def gen_ops_preselected() -> List[CKGemmOperation]: + """ + Manually selected (through benchmarking) F16/F16/F16 Row/Col/Row instances + """ + ck_gemm_f16_rcr = partial( + CKGemmOperation, + a_layout="Row", + b_layout="Col", + c_layout="Row", + a_element_dtype="F16", + b_element_dtype="F16", + c_element_dtype="F16", + acc_dtype="F32", + c_shuffle_dtype="F16", + a_elementwise_op="PassThrough", + b_elementwise_op="PassThrough", + c_elementwise_op="PassThrough", + k_per_block=64, + a_k1=8, + b_k1=8, + a_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + a_block_transfer_src_access_order=(1, 0, 2), + a_block_transfer_src_vector_dim=2, + a_block_transfer_src_scalar_per_vector=8, + a_block_transfer_dst_scalar_per_vector_ak1=8, + a_block_lds_extra_m=0, + b_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + b_block_transfer_src_access_order=(1, 0, 2), + b_block_transfer_src_vector_dim=2, + b_block_transfer_src_scalar_per_vector=8, + b_block_transfer_dst_scalar_per_vector_bk1=8, + b_block_lds_extra_n=0, + a_compute_dtype="F16", + b_compute_dtype="F16", + ) + ck_gemm_f16_rcr_compute_friendly = partial( + ck_gemm_f16_rcr, + block_size=256, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1), + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ) + ck_gemm_f16_rcr_memory_friendly = partial( + ck_gemm_f16_rcr, + block_size=128, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1), + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Interwave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v2", + ) + ck_gemm_f16_rcr_latency_friendly = partial( + ck_gemm_f16_rcr, + 
gemm_specialization="GemmSpecialization::Default", + block_size=128, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1), + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v1", + ) + return [ + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=224, + n_per_block=256, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=7, + n_xdl_per_wave=8, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v4", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v5", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v4", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v5", + ), + 
ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=16, + n_per_block=32, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=16, + n_per_block=32, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=16, + n_per_block=64, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=64, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=32, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=64, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=2, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=2, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + 
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 64, + 1, + 2, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=64, + n_per_block=32, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=32, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=2, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_latency_friendly( + m_per_block=16, + n_per_block=32, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + ), + ck_gemm_f16_rcr_latency_friendly( + m_per_block=32, + n_per_block=16, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + ), + ] + + +if __name__ == "__main__": + print(gen_ops_library()) diff --git a/python/ck4inductor/universal_gemm/op.py b/python/ck4inductor/universal_gemm/op.py new file mode 100644 index 0000000000..ab541c5fb9 --- /dev/null +++ b/python/ck4inductor/universal_gemm/op.py @@ -0,0 +1,95 @@ +from dataclasses import asdict, dataclass +from typing import Optional, Tuple + + +@dataclass +class CKGemmOperation: + """ + A python dataclass storing the template parameters of a CK Universal Gemm template instance + """ + + a_layout: str + b_layout: str + c_layout: str + + a_element_dtype: str + b_element_dtype: str + c_element_dtype: str + + acc_dtype: str + c_shuffle_dtype: str + + a_elementwise_op: str + b_elementwise_op: str + c_elementwise_op: str + + gemm_specialization: str + + block_size: int + + m_per_block: int + n_per_block: int + k_per_block: int + + a_k1: int + b_k1: int + + m_per_xdl: int + n_per_xdl: int + + m_xdl_per_wave: int + n_xdl_per_wave: int + + a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int] + a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] + a_block_transfer_src_access_order: Tuple[int, int, int] + a_block_transfer_src_vector_dim: int + a_block_transfer_src_scalar_per_vector: int + a_block_transfer_dst_scalar_per_vector_ak1: int + a_block_lds_extra_m: bool + + b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int] + b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] + b_block_transfer_src_access_order: Tuple[int, int, int] + + b_block_transfer_src_vector_dim: int + b_block_transfer_src_scalar_per_vector: int + b_block_transfer_dst_scalar_per_vector_bk1: int + b_block_lds_extra_n: bool + + c_shuffle_m_xdl_per_wave_per_shuffle: int + c_shuffle_n_xdl_per_wave_per_shuffle: int + + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: ( + Tuple[int, int, int, int] + ) + c_shuffle_block_transfer_scalar_per_vector_n_per_block: int + + block_gemm_pipeline_scheduler: str + block_gemm_pipeline_version: Optional[str] + + a_compute_dtype: 
Optional[str]
+    b_compute_dtype: Optional[str]
+
+    def name(self):
+        # C++ alias for the template instance
+        return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}"
+
+    def key_name(self):
+        # TBD; must be unique per instance. Intended to be used as a dict key.
+        return "_".join(
+            [
+                "K"
+                + field_name.replace("_", "").lower()
+                + "V"
+                + (
+                    "x".join(map(str, iter(field_value)))
+                    if isinstance(field_value, tuple)
+                    else str(field_value).replace(":", "")
+                )
+                for field_name, field_value in self.dict_items()
+            ]
+        )
+
+    def dict_items(self):
+        return asdict(self).items()
diff --git a/python/ck4inductor/util.py b/python/ck4inductor/util.py
new file mode 100644
index 0000000000..79d6be00f3
--- /dev/null
+++ b/python/ck4inductor/util.py
@@ -0,0 +1,7 @@
+import functools
+import os
+
+
+@functools.lru_cache(None)
+def library_path():
+    return os.path.join(os.path.dirname(__file__), "library")
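
Usage sketch (for reference, not part of the patch): a minimal example of consuming
the installed package, assuming `rocm-composable-kernel` has been pip-installed as
described in the commit message. The script name is hypothetical; the imported
functions (`library_path`, `gen_ops_library`, `gen_ops_preselected`) and the
`CKGemmOperation.name()` method are the ones introduced by this patch.

    # list_ck_instances.py (hypothetical example script, not included in this patch)
    from ck4inductor.util import library_path
    from ck4inductor.universal_gemm.gen_instances import (
        gen_ops_library,
        gen_ops_preselected,
    )

    # Path to the CK sources bundled into the wheel via the pyproject.toml package-data.
    print("bundled CK library:", library_path())

    # Parse every Universal Gemm instance found in the bundled gemm_universal sources;
    # returns an empty list if the sources are missing.
    ops = gen_ops_library()
    print(f"parsed {len(ops)} instances from the CK library")

    # Alternatively, use the smaller, hand-picked F16 Row/Col/Row set.
    for op in gen_ops_preselected()[:3]:
        print(op.name())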