-
Notifications
You must be signed in to change notification settings - Fork 139
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Python] Add batched gemm instances parsing (#1684)
* add op * do not insert ds parameters as they are already parsed * reset ds parameters * apply ruff
- Loading branch information
1 parent
cff7fab
commit 44828b7
Showing
3 changed files
with
249 additions
and
3 deletions.
There are no files selected for viewing
149 changes: 149 additions & 0 deletions
149
python/ck4inductor/batched_universal_gemm/gen_instances.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# SPDX-License-Identifier: MIT | ||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
import logging | ||
import os | ||
import subprocess | ||
from dataclasses import replace | ||
from functools import lru_cache | ||
from typing import List | ||
|
||
from ..util import library_path | ||
|
||
from .op import CKBatchedGemmOperation | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
def _ck_library_dir(): | ||
gemm_instances_path = os.path.join( | ||
library_path(), | ||
"src", | ||
"tensor_operation_instance", | ||
"gpu", | ||
"gemm_universal_batched", | ||
) | ||
if not os.path.exists(gemm_instances_path): | ||
log.error("CK library path %s does not exist", gemm_instances_path) | ||
return None | ||
return gemm_instances_path | ||
|
||
|
||
def parse_instances(str_instances: List[str]) -> List[CKBatchedGemmOperation]: | ||
""" | ||
Parse the lines containing Universal Gemm template instances into `CKBatchedGemmOperation` instances | ||
""" | ||
|
||
def maybe_int(s): | ||
try: | ||
return int(s) | ||
except ValueError: | ||
return s | ||
|
||
op_instances = [] | ||
for line in str_instances: | ||
s_template_args = line.split("DeviceBatchedGemmMultiD_Xdl_CShuffle_V3")[ | ||
-1 | ||
].strip("<>, ") | ||
template_args = [] | ||
i_current = 0 | ||
while i_current < len(s_template_args): | ||
if s_template_args[i_current] == " ": | ||
# skip whitespace | ||
i_current += 1 | ||
continue | ||
elif s_template_args[i_current : i_current + 2] == "S<": | ||
# parse template S<Index...> | ||
i_next = s_template_args.find(">", i_current) | ||
template_args.append( | ||
tuple(map(int, s_template_args[i_current + 2 : i_next].split(","))) | ||
) | ||
i_current = i_next + 2 | ||
else: | ||
# all string attributes must be either type aliases or global constants in C++ | ||
i_next = s_template_args.find(",", i_current) | ||
template_args.append( | ||
maybe_int( | ||
s_template_args[i_current : i_next if i_next != -1 else None] | ||
) | ||
) | ||
if i_next != -1: | ||
i_current = i_next + 1 | ||
if i_next == -1: | ||
break | ||
|
||
# ds layout and dtype are parsed as placeholder; reset value | ||
template_args[2] = tuple() # ds layout | ||
template_args[6] = tuple() # ds dtype | ||
|
||
new_instance = CKBatchedGemmOperation( | ||
*template_args, # type: ignore[arg-type] | ||
) | ||
|
||
op_instances.append(new_instance) | ||
return op_instances | ||
|
||
|
||
@lru_cache(None) | ||
def gen_ops_library() -> List[CKBatchedGemmOperation]: | ||
""" | ||
Parse the Universal Gemm instances defined in the composable kernel library folder. | ||
""" | ||
ck_library_dir = _ck_library_dir() | ||
if not ck_library_dir: | ||
return [] | ||
|
||
grep_result = subprocess.run( | ||
[ | ||
"grep", | ||
"-inR", | ||
"DeviceBatchedGemmMultiD_Xdl_CShuffle_V3", | ||
_ck_library_dir(), | ||
], | ||
capture_output=True, | ||
text=True, | ||
) | ||
|
||
op_instances = parse_instances(grep_result.stdout.strip().split("\n")) | ||
|
||
log.debug("ck instances from library: %d", len(op_instances)) | ||
|
||
schedulers = [ | ||
"BlockGemmPipelineScheduler::Intrawave", | ||
"BlockGemmPipelineScheduler::Interwave", | ||
] | ||
gemm_specs = [ | ||
"GemmSpecialization::Default", | ||
"GemmSpecialization::MPadding", | ||
"GemmSpecialization::NPadding", | ||
"GemmSpecialization::KPadding", | ||
"GemmSpecialization::MNPadding", | ||
"GemmSpecialization::MKPadding", | ||
"GemmSpecialization::NKPadding", | ||
"GemmSpecialization::MNKPadding", | ||
] | ||
|
||
# substitute templated args by looping through their domains | ||
substitute_instances = [] | ||
for instance in op_instances: | ||
sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched" | ||
sub_spec = instance.gemm_specialization == "GemmSpec" | ||
schedulers_range = ( | ||
schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler] | ||
) | ||
spec_range = gemm_specs if sub_spec else [instance.gemm_specialization] | ||
for scheduler in schedulers_range: | ||
for spec in spec_range: | ||
substitute_instances.append( | ||
replace( | ||
instance, | ||
block_gemm_pipeline_scheduler=scheduler, | ||
gemm_specialization=spec, | ||
) | ||
) | ||
|
||
return substitute_instances | ||
|
||
|
||
if __name__ == "__main__": | ||
print(gen_ops_library()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# SPDX-License-Identifier: MIT | ||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
from dataclasses import asdict, dataclass | ||
from typing import Optional, Tuple | ||
|
||
|
||
@dataclass | ||
class CKBatchedGemmOperation: | ||
""" | ||
A python dataclass storing the template parameters of a CK Universal Gemm template instance | ||
""" | ||
|
||
a_layout: str | ||
b_layout: str | ||
ds_layouts: Tuple[str] # addmm specific | ||
c_layout: str | ||
|
||
a_element_dtype: str | ||
b_element_dtype: str | ||
ds_element_dtypes: Tuple[str] # addmm specific | ||
c_element_dtype: str | ||
|
||
acc_dtype: str | ||
c_shuffle_dtype: str | ||
|
||
a_elementwise_op: str | ||
b_elementwise_op: str | ||
c_elementwise_op: str | ||
|
||
gemm_specialization: str | ||
|
||
block_size: int | ||
|
||
m_per_block: int | ||
n_per_block: int | ||
k_per_block: int | ||
|
||
a_k1: int | ||
b_k1: int | ||
|
||
m_per_xdl: int | ||
n_per_xdl: int | ||
|
||
m_xdl_per_wave: int | ||
n_xdl_per_wave: int | ||
|
||
a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int] | ||
a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] | ||
a_block_transfer_src_access_order: Tuple[int, int, int] | ||
a_block_transfer_src_vector_dim: int | ||
a_block_transfer_src_scalar_per_vector: int | ||
a_block_transfer_dst_scalar_per_vector_ak1: int | ||
a_block_lds_extra_m: bool | ||
|
||
b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int] | ||
b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] | ||
b_block_transfer_src_access_order: Tuple[int, int, int] | ||
|
||
b_block_transfer_src_vector_dim: int | ||
b_block_transfer_src_scalar_per_vector: int | ||
b_block_transfer_dst_scalar_per_vector_bk1: int | ||
b_block_lds_extra_n: bool | ||
|
||
c_shuffle_m_xdl_per_wave_per_shuffle: int | ||
c_shuffle_n_xdl_per_wave_per_shuffle: int | ||
|
||
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: ( | ||
Tuple[int, int, int, int] | ||
) | ||
c_shuffle_block_transfer_scalar_per_vector_n_per_block: Tuple[int] | ||
block_gemm_pipeline_scheduler: str | ||
block_gemm_pipeline_version: str | ||
|
||
a_compute_dtype: Optional[str] = None | ||
b_compute_dtype: Optional[str] = None | ||
|
||
def name(self): | ||
# cpp alias for template instance | ||
return f"ck_device_batched_gemm_multi_d_xdl_c_shuffle_v3_{self.key_name()}" | ||
|
||
def key_name(self): | ||
# TBD; must be unique per instance. Intended to use as dict key | ||
return "_".join( | ||
[ | ||
"K" | ||
+ field_name.replace("_", "").lower() | ||
+ "V" | ||
+ ( | ||
"x".join(map(str, iter(field_value))) | ||
if isinstance(field_value, tuple) | ||
else str(field_value).replace(":", "") | ||
) | ||
for field_name, field_value in self.dict_items() | ||
] | ||
) | ||
|
||
def dict_items(self): | ||
return asdict(self).items() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters